From 30b3af9423d9ed8d151b05653120d35f0c86a5a1 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 13 Nov 2024 10:10:18 -1000 Subject: [PATCH 01/11] Simplify Matchless.Let around recursion (#1254) --- .../scala/org/bykn/bosatsu/Matchless.scala | 42 +++++---- .../org/bykn/bosatsu/MatchlessToValue.scala | 89 +++++++++---------- .../bosatsu/codegen/python/PythonGen.scala | 37 +++----- .../org/bykn/bosatsu/MatchlessTests.scala | 27 +++--- 4 files changed, 89 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 2e30c55fc..4c06e1ad5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -13,7 +13,11 @@ object Matchless { // these hold bindings either in the code, or temporary // local ones, note CheapExpr never trigger a side effect sealed trait CheapExpr extends Expr - sealed abstract class FnExpr extends Expr + sealed abstract class FnExpr extends Expr { + def captures: List[Expr] + // this is set if the function is recursive + def recursiveName: Option[Bindable] + } sealed abstract class StrPart object StrPart { @@ -78,7 +82,7 @@ object Matchless { // name is set for recursive (but not tail recursive) methods case class Lambda( captures: List[Expr], - name: Option[Bindable], + recursiveName: Option[Bindable], args: NonEmptyList[Bindable], expr: Expr ) extends FnExpr @@ -92,7 +96,9 @@ object Matchless { name: Bindable, arg: NonEmptyList[Bindable], body: Expr - ) extends FnExpr + ) extends FnExpr { + val recursiveName: Option[Bindable] = Some(name) + } case class Global(pack: PackageName, name: Bindable) extends CheapExpr @@ -108,7 +114,7 @@ object Matchless { // note fn is never an App case class App(fn: Expr, arg: NonEmptyList[Expr]) extends Expr case class Let( - arg: Either[LocalAnon, (Bindable, RecursionKind)], + arg: Either[LocalAnon, Bindable], expr: Expr, in: Expr ) extends Expr @@ -324,12 +330,19 @@ object Matchless { e: TypedExpr[A], rec: RecursionKind, slots: LambdaState - ): F[Expr] = { - lazy val e0 = loop(e, if (rec.isRecursive) slots.inLet(name) else slots) + ): F[Expr] = rec match { case RecursionKind.Recursive => - def letrec(e: Expr): Expr = - Let(Right((name, RecursionKind.Recursive)), e, Local(name)) + lazy val e0 = loop(e, slots.inLet(name)) + def letrec(expr: Expr): Expr = + expr match { + case fn: FnExpr if fn.recursiveName == Some(name) => fn + case fn: FnExpr => + // loops always have a function name + sys.error(s"expected ${fn.recursiveName} == Some($name) in ${e.repr.render(80)} which compiled to $fn") + case _ => + sys.error(s"expected ${e.repr.render(80)} to compile to a function, but got: $expr") + } // this could be tail recursive if (SelfCallKind(name, e) == SelfCallKind.TailCall) { @@ -345,20 +358,20 @@ object Matchless { val (slots1, caps) = slots.inLet(name).lambdaFrees(frees) loop(body, slots1) .map { v => - letrec(LoopFn(caps, name, args, v)) + LoopFn(caps, name, args, v) } + // $COVERAGE-OFF$ case _ => // TODO: I don't think this case should ever happen in real code // but it definitely does in fuzz tests e0.map(letrec) + // $COVERAGE-ON$ } } else { - // otherwise let rec x = fn in x e0.map(letrec) } - case RecursionKind.NonRecursive => e0 + case RecursionKind.NonRecursive => loop(e, slots) } - } def loop(te: TypedExpr[A], slots: LambdaState): F[Expr] = te match { @@ -396,7 +409,7 @@ object Matchless { .mapN(App(_, _)) case TypedExpr.Let(a, e, in, r, _) => (loopLetVal(a, e, r, slots.unname), loop(in, slots)) - .mapN(Let(Right((a, r)), _, _)) + .mapN(Let(Right(a), _, _)) case TypedExpr.Literal(lit, _, _) => Monad[F].pure(Literal(lit)) case TypedExpr.Match(arg, branches, _) => ( @@ -766,8 +779,7 @@ object Matchless { def lets(binds: List[(Bindable, Expr)], in: Expr): Expr = binds.foldRight(in) { case ((b, e), r) => - val arg = Right((b, RecursionKind.NonRecursive)) - Let(arg, e, r) + Let(Right(b), e, r) } def checkLets(binds: List[LocalAnonMut], in: Expr): Expr = diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index 8ed32afc5..f2a93e3ca 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -94,12 +94,9 @@ object MatchlessToValue { () } - def capture(it: Vector[Value], name: Option[Bindable]): Scope = + def capture(it: Vector[Value]): Scope = Scope( - name match { - case None => Map.empty - case Some(n) => Map((n, locals(n))) - }, + Map.empty, LongMap.empty, MLongMap(), it @@ -353,7 +350,7 @@ object MatchlessToValue { // It may make things go faster // if the caps are really small // or if we can GC things sooner. - val scope1 = scope.capture(caps.map(s => s(scope)), Some(fnName)) + val scope1 = scope.capture(caps.map(s => s(scope))) FnValue { allArgs => var registers: NonEmptyList[Value] = allArgs @@ -393,30 +390,45 @@ object MatchlessToValue { // the locals can be recusive, so we box into Eval for laziness def loop(me: Expr): Scoped[Value] = me match { - case Lambda(caps, name, args, res) => + case Lambda(Nil, None, args, res) => val resFn = loop(res) - - if (caps.isEmpty && name.isEmpty) { - // we can allocate once if there is no closure - val scope1 = Scope.empty() - val fn = FnValue { argV => + // we can allocate once if there is no closure + val scope1 = Scope.empty() + val fn = FnValue { argV => + val scope2 = scope1.letAll(args, argV) + resFn(scope2) + } + Static(fn) + case Lambda(caps, None, args, res) => + val resFn = loop(res) + val capScoped = caps.map(loop).toVector + Dynamic { scope => + val scope1 = scope + .capture(capScoped.map(scoped => scoped(scope))) + + // hopefully optimization/normalization has lifted anything + // that doesn't depend on argV above this lambda + FnValue { argV => val scope2 = scope1.letAll(args, argV) resFn(scope2) } - Static(fn) - } else { - val capScoped = caps.map(loop).toVector - Dynamic { scope => - val scope1 = scope - .capture(capScoped.map(scoped => scoped(scope)), name) - - // hopefully optimization/normalization has lifted anything - // that doesn't depend on argV above this lambda - FnValue { argV => - val scope2 = scope1.letAll(args, argV) - resFn(scope2) - } + } + case Lambda(caps, Some(name), args, res) => + val resFn = loop(res) + val capScoped = caps.map(loop).toVector + Dynamic { scope => + lazy val scope1: Scope = scope + .capture(capScoped.map(scoped => scoped(scope))) + .let(name, Eval.later(fn)) + + // hopefully optimization/normalization has lifted anything + // that doesn't depend on argV above this lambda + lazy val fn = FnValue { argV => + val scope2 = scope1.letAll(args, argV) + resFn(scope2) } + + fn } case LoopFn(caps, thisName, args, body) => val bodyFn = loop(body) @@ -444,34 +456,15 @@ object MatchlessToValue { Applicative[Scoped].map2(exprFn, argsFn) { (fn, args) => fn.applyAll(args) } - case Let(Right((n1, r)), loopFn @ LoopFn(_, n2, _, _), Local(n3)) - if (n1 === n3) && (n1 === n2) && r.isRecursive => - // LoopFn already correctly handles recursion - loop(loopFn) case Let(localOrBind, value, in) => val valueF = loop(value) val inF = loop(in) localOrBind match { - case Right((b, rec)) => - if (rec.isRecursive) { - - inF.withScope { scope => - // this is the only one that should - // use lazy/Eval.later - // we use it to tie the recursive knot - lazy val scope1: Scope = - scope.let(b, vv) - - lazy val vv = Eval.later(valueF(scope1)) - - scope1 - } - } else { - inF.withScope { (scope: Scope) => - val vv = Eval.now(valueF(scope)) - scope.let(b, vv) - } + case Right(b) => + inF.withScope { (scope: Scope) => + val vv = Eval.now(valueF(scope)) + scope.let(b, vv) } case Left(LocalAnon(l)) => inF.withScope { (scope: Scope) => diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 459a530e5..f66b54a39 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -9,7 +9,6 @@ import org.bykn.bosatsu.{ Matchless, Par, Parser, - RecursionKind } import org.bykn.bosatsu.codegen.Idents import org.bykn.bosatsu.rankn.Type @@ -506,7 +505,7 @@ object PythonGen { // if we have a top level let rec with the same name, handle it more cleanly me match { - case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2)) + case Let(Right(n1), inner, Local(n2)) if ((n1 === name) && (n2 === name)) => // we can just bind now at the top level for { @@ -516,12 +515,6 @@ object PythonGen { case _ => ops.loop(inner, None).map(nm := _) } } yield res - case Let(Right((n1, RecursionKind.Recursive)), fn: FnExpr, Local(n2)) - if (n1 === name) && (n2 === name) => - for { - nm <- Env.topLevelName(name) - res <- ops.topFn(nm, fn, None) - } yield res case fn: FnExpr => for { nm <- Env.topLevelName(name) @@ -1723,7 +1716,7 @@ object PythonGen { val inF = loop(in, slotName) localOrBind match { - case Right((b, _)) => + case Right(b) => // for fn, bosatsu doesn't allow bind name // shadowing, so the bind order of the name // doesn't matter @@ -1747,24 +1740,14 @@ object PythonGen { val inF = loop(in, slotName) localOrBind match { - case Right((b, rec)) => - if (rec.isRecursive) { - // value b is in scope first - for { - bi <- Env.bind(b) - v <- loop(notFn, slotName) - ine <- inF - _ <- Env.unbind(b) - } yield ((bi := v).withValue(ine)) - } else { - // value b is in scope after ve - for { - ve <- loop(notFn, slotName) - bi <- Env.bind(b) - ine <- inF - _ <- Env.unbind(b) - } yield ((bi := ve).withValue(ine)) - } + case Right(b) => + // value b is in scope after ve + for { + ve <- loop(notFn, slotName) + bi <- Env.bind(b) + ine <- inF + _ <- Env.unbind(b) + } yield ((bi := ve).withValue(ine)) case Left(LocalAnon(l)) => // anonymous names never shadow (Env.nameForAnon(l), loop(notFn, slotName)) diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index deb8a7356..36beb8d2a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -10,8 +10,8 @@ import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ import Identifier.{Bindable, Constructor} import rankn.DataRepr -import cats.implicits._ import org.scalatest.funsuite.AnyFunSuite +import scala.util.Try class MatchlessTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = @@ -52,28 +52,23 @@ class MatchlessTest extends AnyFunSuite { test("matchless.fromLet is pure: f(x) == f(x)") { forAll(genInputs) { case (b, r, t, fn) => - def run(): Matchless.Expr = - Matchless.fromLet(b, r, t)(fn) + def run(): Option[Matchless.Expr] = + // ill-formed inputs can fail + Try(Matchless.fromLet(b, r, t)(fn)).toOption assert(run() == run()) } } - val genMatchlessExpr: Gen[Matchless.Expr] = + lazy val genMatchlessExpr: Gen[Matchless.Expr] = genInputs.map { case (b, r, t, fn) => - Matchless.fromLet(b, r, t)(fn) + // ill-formed inputs can fail + Try(Matchless.fromLet(b, r, t)(fn)).toOption + } + .flatMap { + case Some(e) => Gen.const(e) + case None => genMatchlessExpr } - - test("regressions") { - // this is illegal code, but it shouldn't throw a match error: - val name = Identifier.Name("foo") - val te = TypedExpr.Local(name, rankn.Type.IntType, ()) - // this should not throw - val me = Matchless.fromLet(name, RecursionKind.Recursive, te)( - fnFromTypeEnv(rankn.TypeEnv.empty) - ) - assert(me != null) - } def genNE[A](max: Int, ga: Gen[A]): Gen[NonEmptyList[A]] = for { From 93f9f744ddcfa4ba196f59f98ac6d70975c87560 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 13 Nov 2024 13:11:06 -1000 Subject: [PATCH 02/11] simplify Lets more in Matchless, fix fallout in Python (#1255) * simplify Lets more in Matchless, fix fallout in Python * turn down recursion in generated expressions --- .../scala/org/bykn/bosatsu/Matchless.scala | 24 ++++++++++++++++--- .../bosatsu/codegen/python/PythonGen.scala | 21 +++++++++------- .../src/test/scala/org/bykn/bosatsu/Gen.scala | 17 +++++++------ 3 files changed, 43 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 4c06e1ad5..11f704767 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -118,6 +118,19 @@ object Matchless { expr: Expr, in: Expr ) extends Expr + + object Let { + def apply(arg: Bindable, expr: Expr, in: Expr): Expr = + // don't create let x = y in x, just return y + if (in == Local(arg)) expr + else Let(Right(arg), expr, in) + + def apply(arg: LocalAnon, expr: Expr, in: Expr): Expr = + // don't create let x = y in x, just return y + if (in == arg) expr + else Let(Left(arg), expr, in) + } + case class LetMut(name: LocalAnonMut, span: Expr) extends Expr case class Literal(lit: Lit) extends CheapExpr @@ -233,7 +246,7 @@ object Matchless { nm <- tmp bound = LocalAnon(nm) res <- fn(bound) - } yield Let(Left(bound), arg, res) + } yield Let(bound, arg, res) } } @@ -409,7 +422,7 @@ object Matchless { .mapN(App(_, _)) case TypedExpr.Let(a, e, in, r, _) => (loopLetVal(a, e, r, slots.unname), loop(in, slots)) - .mapN(Let(Right(a), _, _)) + .mapN(Let(a, _, _)) case TypedExpr.Literal(lit, _, _) => Monad[F].pure(Literal(lit)) case TypedExpr.Match(arg, branches, _) => ( @@ -779,11 +792,16 @@ object Matchless { def lets(binds: List[(Bindable, Expr)], in: Expr): Expr = binds.foldRight(in) { case ((b, e), r) => - Let(Right(b), e, r) + Let(b, e, r) } def checkLets(binds: List[LocalAnonMut], in: Expr): Expr = binds.foldLeft(in) { case (rest, anon) => + // TODO: sometimes we generate code like + // LetMut(x, Always(SetMut(x, y), f)) + // with no side effects in y or f + // this would be better written as + // Let(x, y, f) LetMut(anon, rest) } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index f66b54a39..a120457ac 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -85,7 +85,7 @@ object PythonGen { case other => // $COVERAGE-OFF$ throw new IllegalStateException( - s"unexpected deref: $b with bindings: $other" + s"unexpected deref: $b with bindings: $other, in $this" ) // $COVERAGE-ON$ } @@ -1617,20 +1617,22 @@ object PythonGen { def loop(expr: Expr, slotName: Option[Code.Ident]): Env[ValueLike] = expr match { - case Lambda(captures, _, args, res) => - // we ignore name because python already supports recursion - // we can use topLevelName on makeDefs since they are already - // shadowing in the same rules as bosatsu + case Lambda(captures, recName, args, res) => + val defName = recName match { + case None => Env.newAssignableVar + case Some(n) => Env.bind(n) + } ( args.traverse(Env.topLevelName(_)), + defName, makeSlots(captures, slotName)(loop(res, _)) ) .flatMapN { - case (args, (None, x: Expression)) => + case (args, _, (None, x: Expression)) if recName.isEmpty => Env.pure(Code.Lambda(args.toList, x)) - case (args, (prefix, v)) => + case (args, defName, (prefix, v)) => for { - defName <- Env.newAssignableVar + _ <- recName.fold(Monad[Env].unit)(Env.unbind(_)) defn = Env.makeDef(defName, args, v) block = Code.blockFromList(prefix.toList ::: defn :: Nil) } yield block.withValue(defName) @@ -1649,7 +1651,7 @@ object PythonGen { } for { - nameI <- Env.deref(thisName) + nameI <- Env.bind(thisName) as <- boundA subs <- subsA (prefix, body) <- makeSlots(captures, slotName)(loop(body, _)) @@ -1657,6 +1659,7 @@ object PythonGen { loopRes <- Env.buildLoop(nameI, subs1, body) // we have bound the args twice: once as args, once as interal muts _ <- subs.traverse_ { case (a, _) => Env.unbind(a) } + _ <- Env.unbind(thisName) } yield Code .blockFromList(prefix.toList :+ loopRes) .withValue(nameI) diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index ef069e8e5..9920941df 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -677,6 +677,12 @@ object Generators { useAnnotation = useAnnotation ) + val genRecursionKind: Gen[RecursionKind] = + Gen.frequency( + (20, Gen.const(RecursionKind.NonRecursive)), + (1, Gen.const(RecursionKind.Recursive)) + ) + def matchGen( argGen0: Gen[NonBinding], bodyGen: Gen[Declaration] @@ -696,10 +702,7 @@ object Generators { for { cnt <- Gen.choose(1, 2) - kind <- Gen.frequency( - (10, Gen.const(RecursionKind.NonRecursive)), - (1, Gen.const(RecursionKind.Recursive)) - ) + kind <- genRecursionKind expr <- argGen cases <- optIndent(nonEmptyN(genCase, cnt)) } yield Match(kind, expr, cases)(emptyRegion) @@ -1401,7 +1404,7 @@ object Generators { bindIdentGen, recurse, recurse, - Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), + genRecursionKind, genTag ) .map { case (n, ex, in, rec, tag) => @@ -1618,7 +1621,7 @@ object Generators { ) val oneLet = Gen.zip( bindIdentGen.filter(b => !exts(b)), - Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), + genRecursionKind, genTypedExpr(genA, 4, theseTypes) ) @@ -1782,7 +1785,7 @@ object Generators { bindIdentGen, recur, recur, - Gen.oneOf(RecursionKind.Recursive, RecursionKind.NonRecursive), + genRecursionKind, genA ) .map { case (a, e, in, r, t) => Let(a, e, in, r, t) } From fb3cc7e16febda78b76021e3f461b95dc20378c8 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Fri, 15 Nov 2024 12:52:57 -1000 Subject: [PATCH 03/11] Add utils to test matchless compilation (#1257) --- .../bykn/bosatsu/MatchlessFromTypedExpr.scala | 3 +- .../scala/org/bykn/bosatsu/PackageMap.scala | 19 +++++++++++- .../org/bykn/bosatsu/MatchlessTests.scala | 11 +++++++ .../scala/org/bykn/bosatsu/TestUtils.scala | 29 +++++++++++++++++++ 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala index a467cecbf..2ecfd87fc 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala @@ -5,10 +5,11 @@ import Identifier.Bindable import cats.implicits._ object MatchlessFromTypedExpr { + type Compiled = Map[PackageName, List[(Bindable, Matchless.Expr)]] // compile a set of packages given a set of external remappings def compile[A]( pm: PackageMap.Typed[A] - )(implicit ec: Par.EC): Map[PackageName, List[(Bindable, Matchless.Expr)]] = { + )(implicit ec: Par.EC): Compiled = { val gdr = pm.getDataRepr diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala index 51f635405..aca53813e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.bykn.bosatsu.graph.Memoize +import org.bykn.bosatsu.graph.{Memoize, Toposort} import cats.{Foldable, Monad, Show} import cats.data.{ Ior, @@ -49,6 +49,23 @@ case class PackageMap[A, B, C, +D]( toMap.iterator.map { case (name, pack) => (name, ev(pack).externalDefs) }.toMap + + def topoSort( + ev: Package[A, B, C, D] <:< Package.Typed[Any] + ): Toposort.Result[PackageName] = { + + val packNames = toMap.keys.iterator.toList.sorted + + def nfn(p: PackageName): List[PackageName] = + toMap.get(p) match { + case None => Nil + case Some(pack) => + val tpack = ev(pack) + tpack.imports.map(_.pack.name).sorted + } + + Toposort.sort(packNames)(nfn) + } } object PackageMap { diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index 36beb8d2a..6c5014b57 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -171,4 +171,15 @@ class MatchlessTest extends AnyFunSuite { case _ => () } } + + test("check compilation of some matchless") { + TestUtils.checkMatchless(""" +x = 1 +""") { binds => + val map = binds.toMap + + assert(map.contains(Identifier.Name("x"))) + assert(map(Identifier.Name("x")) == Matchless.Literal(Lit(1))) + } + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index 7cf9d5690..c4593e097 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -99,6 +99,35 @@ object TestUtils { } } + def checkMatchless[A]( + statement: String + )(fn: List[(Identifier.Bindable, Matchless.Expr)] => A): A = { + val stmts = Parser.unsafeParse(Statement.parser, statement) + Package.inferBody(testPackage, Nil, stmts).strictToValidated match { + case Validated.Invalid(errs) => + val lm = LocationMap(statement) + val packMap = Map((testPackage, (lm, statement))) + val msg = errs.toList + .map { err => + err.message(packMap, LocationMap.Colorize.None) + } + .mkString("", "\n==========\n", "\n") + sys.error("inference failure: " + msg) + case Validated.Valid(program) => + // make sure all the TypedExpr are valid + program.lets.foreach { case (_, _, te) => assertValid(te) } + val pack: Package.Typed[Declaration] = Package(testPackage, Nil, Nil, (program, ImportMap.empty)) + val pm: PackageMap.Typed[Declaration] = PackageMap.empty + pack + PackageMap.predefCompiled + val srv = Par.newService() + try { + implicit val ec = Par.ecFromService(srv) + val comp = MatchlessFromTypedExpr.compile(pm) + fn(comp(testPackage)) + } + finally Par.shutdownService(srv) + } + } + def makeInputArgs(files: List[(Int, Any)]): List[String] = ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { case (idx, _) => "--input" :: idx.toString :: Nil From 0c5ece34112a65b7c7d4bc847af271d02aae8362 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Sun, 17 Nov 2024 09:47:34 -1000 Subject: [PATCH 04/11] Generating C code from Matchless (#1253) * checkpoint * checkpoint more work * simplify * implement fns * checkpoint most of c tranlation implemented * fill in more impls * ignore compiled c code * work on string representation * implement moar * code complete * add a basic, failing, test * get some tests passing * Add foldLeft * add reverse_concat * optimize ifThenElseV --- .gitignore | 4 +- c_runtime/bosatsu_generated.h | 192 +---- c_runtime/bosatsu_runtime.c | 96 ++- c_runtime/bosatsu_runtime.h | 78 ++ c_runtime/typegen.py | 6 +- .../scala/org/bykn/bosatsu/Matchless.scala | 9 +- .../org/bykn/bosatsu/codegen/Idents.scala | 25 +- .../bykn/bosatsu/codegen/clang/ClangGen.scala | 814 ++++++++++++++++++ .../org/bykn/bosatsu/codegen/clang/Code.scala | 251 +++++- .../org/bykn/bosatsu/MatchlessTests.scala | 2 +- .../scala/org/bykn/bosatsu/TestUtils.scala | 4 +- .../org/bykn/bosatsu/codegen/IdentsTest.scala | 10 + .../bosatsu/codegen/clang/ClangGenTest.scala | 131 +++ 13 files changed, 1396 insertions(+), 226 deletions(-) create mode 100644 core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala create mode 100644 core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala diff --git a/.gitignore b/.gitignore index c877a8163..1f3cd8628 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,6 @@ node_modules/ .bloop project/metals.sbt project/project/* -jsui/bosatsu_ui.js \ No newline at end of file +jsui/bosatsu_ui.js + +c_runtime/bosatsu_runtime.o \ No newline at end of file diff --git a/c_runtime/bosatsu_generated.h b/c_runtime/bosatsu_generated.h index d108565fe..58f6cca1b 100644 --- a/c_runtime/bosatsu_generated.h +++ b/c_runtime/bosatsu_generated.h @@ -2916,7 +2916,7 @@ BValue alloc_closure1(size_t size, BValue* data, BClosure1 fn) { } BValue call_fn1(BValue fn, BValue arg0) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn1 pfn = (BPureFn1)TO_POINTER(fn); return pfn(arg0); } @@ -2928,10 +2928,6 @@ BValue call_fn1(BValue fn, BValue arg0) { } } -BValue value_from_pure_fn1(BPureFn1 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure2Data, BClosure2 fn; size_t slot_len;); @@ -2949,7 +2945,7 @@ BValue alloc_closure2(size_t size, BValue* data, BClosure2 fn) { } BValue call_fn2(BValue fn, BValue arg0, BValue arg1) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn2 pfn = (BPureFn2)TO_POINTER(fn); return pfn(arg0, arg1); } @@ -2961,10 +2957,6 @@ BValue call_fn2(BValue fn, BValue arg0, BValue arg1) { } } -BValue value_from_pure_fn2(BPureFn2 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure3Data, BClosure3 fn; size_t slot_len;); @@ -2982,7 +2974,7 @@ BValue alloc_closure3(size_t size, BValue* data, BClosure3 fn) { } BValue call_fn3(BValue fn, BValue arg0, BValue arg1, BValue arg2) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn3 pfn = (BPureFn3)TO_POINTER(fn); return pfn(arg0, arg1, arg2); } @@ -2994,10 +2986,6 @@ BValue call_fn3(BValue fn, BValue arg0, BValue arg1, BValue arg2) { } } -BValue value_from_pure_fn3(BPureFn3 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure4Data, BClosure4 fn; size_t slot_len;); @@ -3015,7 +3003,7 @@ BValue alloc_closure4(size_t size, BValue* data, BClosure4 fn) { } BValue call_fn4(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn4 pfn = (BPureFn4)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3); } @@ -3027,10 +3015,6 @@ BValue call_fn4(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3) { } } -BValue value_from_pure_fn4(BPureFn4 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure5Data, BClosure5 fn; size_t slot_len;); @@ -3048,7 +3032,7 @@ BValue alloc_closure5(size_t size, BValue* data, BClosure5 fn) { } BValue call_fn5(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn5 pfn = (BPureFn5)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4); } @@ -3060,10 +3044,6 @@ BValue call_fn5(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, B } } -BValue value_from_pure_fn5(BPureFn5 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure6Data, BClosure6 fn; size_t slot_len;); @@ -3081,7 +3061,7 @@ BValue alloc_closure6(size_t size, BValue* data, BClosure6 fn) { } BValue call_fn6(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn6 pfn = (BPureFn6)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5); } @@ -3093,10 +3073,6 @@ BValue call_fn6(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, B } } -BValue value_from_pure_fn6(BPureFn6 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure7Data, BClosure7 fn; size_t slot_len;); @@ -3114,7 +3090,7 @@ BValue alloc_closure7(size_t size, BValue* data, BClosure7 fn) { } BValue call_fn7(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn7 pfn = (BPureFn7)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6); } @@ -3126,10 +3102,6 @@ BValue call_fn7(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, B } } -BValue value_from_pure_fn7(BPureFn7 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure8Data, BClosure8 fn; size_t slot_len;); @@ -3147,7 +3119,7 @@ BValue alloc_closure8(size_t size, BValue* data, BClosure8 fn) { } BValue call_fn8(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn8 pfn = (BPureFn8)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7); } @@ -3159,10 +3131,6 @@ BValue call_fn8(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, B } } -BValue value_from_pure_fn8(BPureFn8 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure9Data, BClosure9 fn; size_t slot_len;); @@ -3180,7 +3148,7 @@ BValue alloc_closure9(size_t size, BValue* data, BClosure9 fn) { } BValue call_fn9(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn9 pfn = (BPureFn9)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8); } @@ -3192,10 +3160,6 @@ BValue call_fn9(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, B } } -BValue value_from_pure_fn9(BPureFn9 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure10Data, BClosure10 fn; size_t slot_len;); @@ -3213,7 +3177,7 @@ BValue alloc_closure10(size_t size, BValue* data, BClosure10 fn) { } BValue call_fn10(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn10 pfn = (BPureFn10)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); } @@ -3225,10 +3189,6 @@ BValue call_fn10(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn10(BPureFn10 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure11Data, BClosure11 fn; size_t slot_len;); @@ -3246,7 +3206,7 @@ BValue alloc_closure11(size_t size, BValue* data, BClosure11 fn) { } BValue call_fn11(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn11 pfn = (BPureFn11)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); } @@ -3258,10 +3218,6 @@ BValue call_fn11(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn11(BPureFn11 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure12Data, BClosure12 fn; size_t slot_len;); @@ -3279,7 +3235,7 @@ BValue alloc_closure12(size_t size, BValue* data, BClosure12 fn) { } BValue call_fn12(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn12 pfn = (BPureFn12)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11); } @@ -3291,10 +3247,6 @@ BValue call_fn12(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn12(BPureFn12 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure13Data, BClosure13 fn; size_t slot_len;); @@ -3312,7 +3264,7 @@ BValue alloc_closure13(size_t size, BValue* data, BClosure13 fn) { } BValue call_fn13(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn13 pfn = (BPureFn13)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12); } @@ -3324,10 +3276,6 @@ BValue call_fn13(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn13(BPureFn13 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure14Data, BClosure14 fn; size_t slot_len;); @@ -3345,7 +3293,7 @@ BValue alloc_closure14(size_t size, BValue* data, BClosure14 fn) { } BValue call_fn14(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn14 pfn = (BPureFn14)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13); } @@ -3357,10 +3305,6 @@ BValue call_fn14(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn14(BPureFn14 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure15Data, BClosure15 fn; size_t slot_len;); @@ -3378,7 +3322,7 @@ BValue alloc_closure15(size_t size, BValue* data, BClosure15 fn) { } BValue call_fn15(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn15 pfn = (BPureFn15)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14); } @@ -3390,10 +3334,6 @@ BValue call_fn15(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn15(BPureFn15 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure16Data, BClosure16 fn; size_t slot_len;); @@ -3411,7 +3351,7 @@ BValue alloc_closure16(size_t size, BValue* data, BClosure16 fn) { } BValue call_fn16(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn16 pfn = (BPureFn16)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15); } @@ -3423,10 +3363,6 @@ BValue call_fn16(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn16(BPureFn16 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure17Data, BClosure17 fn; size_t slot_len;); @@ -3444,7 +3380,7 @@ BValue alloc_closure17(size_t size, BValue* data, BClosure17 fn) { } BValue call_fn17(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn17 pfn = (BPureFn17)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16); } @@ -3456,10 +3392,6 @@ BValue call_fn17(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn17(BPureFn17 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure18Data, BClosure18 fn; size_t slot_len;); @@ -3477,7 +3409,7 @@ BValue alloc_closure18(size_t size, BValue* data, BClosure18 fn) { } BValue call_fn18(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn18 pfn = (BPureFn18)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17); } @@ -3489,10 +3421,6 @@ BValue call_fn18(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn18(BPureFn18 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure19Data, BClosure19 fn; size_t slot_len;); @@ -3510,7 +3438,7 @@ BValue alloc_closure19(size_t size, BValue* data, BClosure19 fn) { } BValue call_fn19(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn19 pfn = (BPureFn19)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18); } @@ -3522,10 +3450,6 @@ BValue call_fn19(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn19(BPureFn19 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure20Data, BClosure20 fn; size_t slot_len;); @@ -3543,7 +3467,7 @@ BValue alloc_closure20(size_t size, BValue* data, BClosure20 fn) { } BValue call_fn20(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn20 pfn = (BPureFn20)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19); } @@ -3555,10 +3479,6 @@ BValue call_fn20(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn20(BPureFn20 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure21Data, BClosure21 fn; size_t slot_len;); @@ -3576,7 +3496,7 @@ BValue alloc_closure21(size_t size, BValue* data, BClosure21 fn) { } BValue call_fn21(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn21 pfn = (BPureFn21)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20); } @@ -3588,10 +3508,6 @@ BValue call_fn21(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn21(BPureFn21 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure22Data, BClosure22 fn; size_t slot_len;); @@ -3609,7 +3525,7 @@ BValue alloc_closure22(size_t size, BValue* data, BClosure22 fn) { } BValue call_fn22(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn22 pfn = (BPureFn22)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21); } @@ -3621,10 +3537,6 @@ BValue call_fn22(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn22(BPureFn22 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure23Data, BClosure23 fn; size_t slot_len;); @@ -3642,7 +3554,7 @@ BValue alloc_closure23(size_t size, BValue* data, BClosure23 fn) { } BValue call_fn23(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn23 pfn = (BPureFn23)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22); } @@ -3654,10 +3566,6 @@ BValue call_fn23(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn23(BPureFn23 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure24Data, BClosure24 fn; size_t slot_len;); @@ -3675,7 +3583,7 @@ BValue alloc_closure24(size_t size, BValue* data, BClosure24 fn) { } BValue call_fn24(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn24 pfn = (BPureFn24)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23); } @@ -3687,10 +3595,6 @@ BValue call_fn24(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn24(BPureFn24 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure25Data, BClosure25 fn; size_t slot_len;); @@ -3708,7 +3612,7 @@ BValue alloc_closure25(size_t size, BValue* data, BClosure25 fn) { } BValue call_fn25(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn25 pfn = (BPureFn25)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24); } @@ -3720,10 +3624,6 @@ BValue call_fn25(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn25(BPureFn25 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure26Data, BClosure26 fn; size_t slot_len;); @@ -3741,7 +3641,7 @@ BValue alloc_closure26(size_t size, BValue* data, BClosure26 fn) { } BValue call_fn26(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn26 pfn = (BPureFn26)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25); } @@ -3753,10 +3653,6 @@ BValue call_fn26(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn26(BPureFn26 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure27Data, BClosure27 fn; size_t slot_len;); @@ -3774,7 +3670,7 @@ BValue alloc_closure27(size_t size, BValue* data, BClosure27 fn) { } BValue call_fn27(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25, BValue arg26) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn27 pfn = (BPureFn27)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25, arg26); } @@ -3786,10 +3682,6 @@ BValue call_fn27(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn27(BPureFn27 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure28Data, BClosure28 fn; size_t slot_len;); @@ -3807,7 +3699,7 @@ BValue alloc_closure28(size_t size, BValue* data, BClosure28 fn) { } BValue call_fn28(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25, BValue arg26, BValue arg27) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn28 pfn = (BPureFn28)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25, arg26, arg27); } @@ -3819,10 +3711,6 @@ BValue call_fn28(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn28(BPureFn28 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure29Data, BClosure29 fn; size_t slot_len;); @@ -3840,7 +3728,7 @@ BValue alloc_closure29(size_t size, BValue* data, BClosure29 fn) { } BValue call_fn29(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25, BValue arg26, BValue arg27, BValue arg28) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn29 pfn = (BPureFn29)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25, arg26, arg27, arg28); } @@ -3852,10 +3740,6 @@ BValue call_fn29(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn29(BPureFn29 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure30Data, BClosure30 fn; size_t slot_len;); @@ -3873,7 +3757,7 @@ BValue alloc_closure30(size_t size, BValue* data, BClosure30 fn) { } BValue call_fn30(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25, BValue arg26, BValue arg27, BValue arg28, BValue arg29) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn30 pfn = (BPureFn30)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25, arg26, arg27, arg28, arg29); } @@ -3885,10 +3769,6 @@ BValue call_fn30(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn30(BPureFn30 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure31Data, BClosure31 fn; size_t slot_len;); @@ -3906,7 +3786,7 @@ BValue alloc_closure31(size_t size, BValue* data, BClosure31 fn) { } BValue call_fn31(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25, BValue arg26, BValue arg27, BValue arg28, BValue arg29, BValue arg30) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn31 pfn = (BPureFn31)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25, arg26, arg27, arg28, arg29, arg30); } @@ -3918,10 +3798,6 @@ BValue call_fn31(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn31(BPureFn31 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - DEFINE_RC_STRUCT(Closure32Data, BClosure32 fn; size_t slot_len;); @@ -3939,7 +3815,7 @@ BValue alloc_closure32(size_t size, BValue* data, BClosure32 fn) { } BValue call_fn32(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, BValue arg4, BValue arg5, BValue arg6, BValue arg7, BValue arg8, BValue arg9, BValue arg10, BValue arg11, BValue arg12, BValue arg13, BValue arg14, BValue arg15, BValue arg16, BValue arg17, BValue arg18, BValue arg19, BValue arg20, BValue arg21, BValue arg22, BValue arg23, BValue arg24, BValue arg25, BValue arg26, BValue arg27, BValue arg28, BValue arg29, BValue arg30, BValue arg31) { - if (IS_STATIC_VALUE(fn)) { + if (IS_PURE_VALUE(fn)) { BPureFn32 pfn = (BPureFn32)TO_POINTER(fn); return pfn(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23, arg24, arg25, arg26, arg27, arg28, arg29, arg30, arg31); } @@ -3951,7 +3827,3 @@ BValue call_fn32(BValue fn, BValue arg0, BValue arg1, BValue arg2, BValue arg3, } } -BValue value_from_pure_fn32(BPureFn32 fn) { - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); -} - diff --git a/c_runtime/bosatsu_runtime.c b/c_runtime/bosatsu_runtime.c index a15c0a452..3b2db2913 100644 --- a/c_runtime/bosatsu_runtime.c +++ b/c_runtime/bosatsu_runtime.c @@ -3,50 +3,7 @@ #include #include -/* -There are a few kinds of values: - -1. pure values: small ints, characters, small strings that can fit into 63 bits. -2. pointers to referenced counted values -3. pointers to static values stack allocated at startup - -to distinguish these cases we allocate pointers such that they are aligned to at least 4 byte -boundaries: - a. ends with 1: pure value - b. ends with 10: static pointer (allocated once and deleteds at the end of the world) - c. ends with 00: refcount pointer. - -We need to know which case we are in because in generic context we need to know -how to clone values. -*/ -#define TAG_MASK 0x3 -#define PURE_VALUE_TAG 0x1 -#define STATIC_VALUE_TAG 0x3 -#define POINTER_TAG 0x0 - -// Utility macros to check the tag of a value -#define IS_PURE_VALUE(ptr) (((uintptr_t)(ptr) & PURE_VALUE_TAG) == PURE_VALUE_TAG) -#define PURE_VALUE(ptr) ((uintptr_t)(ptr) >> 1) -#define IS_STATIC_VALUE(ptr) (((uintptr_t)(ptr) & TAG_MASK) == STATIC_VALUE_TAG) -#define IS_POINTER(ptr) (((uintptr_t)(ptr) & TAG_MASK) == POINTER_TAG) -#define TO_POINTER(ptr) ((uintptr_t)(ptr) & ~TAG_MASK) - -#define DEFINE_RC_STRUCT(name, fields) \ - struct name { \ - atomic_int ref_count; \ - FreeFn free; \ - fields \ - }; \ - typedef struct name name - -#define DEFINE_RC_ENUM(name, fields) \ - struct name { \ - atomic_int ref_count; \ - FreeFn free; \ - ENUM_TAG tag; \ - fields \ - }; \ - typedef struct name name +#define DEFINE_RC_ENUM(name, fields) DEFINE_RC_STRUCT(name, ENUM_TAG tag; fields) DEFINE_RC_STRUCT(RefCounted,); @@ -77,6 +34,9 @@ void free_closure(Closure1Data* s) { DEFINE_RC_ENUM(Enum0,); DEFINE_RC_STRUCT(External, void* external; FreeFn ex_free;); + +DEFINE_RC_STRUCT(BSTS_String, size_t len; char* bytes;); + // A general structure for a reference counted memory block // it is always allocated with len BValue array immediately after typedef struct _Node { @@ -177,6 +137,27 @@ void* get_external(BValue v) { return rc->external; } +void free_string(void* str) { + BSTS_String* casted = (BSTS_String*)str; + free(casted->bytes); + free(str); +} + +// this copies the bytes in, it does not take ownership +BValue bsts_string_from_utf8_bytes_copy(size_t len, char* bytes) { + BSTS_String* str = malloc(sizeof(BSTS_String)); + char* bytes_copy = malloc(sizeof(char) * len); + for(size_t i = 0; i < len; i++) { + bytes_copy[i] = bytes[i]; + } + str->len = len; + str->bytes = bytes_copy; + atomic_init(&str->ref_count, 1); + str->free = (FreeFn)free_string; + + return (BValue)str; +} + // Function to determine the type of the given value pointer and clone if necessary BValue clone_value(BValue value) { if (IS_POINTER(value)) { @@ -227,11 +208,34 @@ BValue make_static(BValue v) { return v; } +BValue read_or_build(_Atomic BValue* target, BConstruct cons) { + BValue result = atomic_load(target); + if (result == NULL) { + result = cons(); + BValue static_version = make_static(result); + BValue expected = NULL; + do { + if (atomic_compare_exchange_weak(target, &expected, static_version)) { + free_on_close(result); + break; + } else { + expected = atomic_load(target); + if (expected != NULL) { + release_value(result); + result = expected; + break; + } + } + } while (1); + } + return result; +} + // Example static BValue make_foo(); static _Atomic BValue __bvalue_foo = NULL; // Add this to the main function to construct all // the top level values before we start BValue foo() { - return CONSTRUCT(&__bvalue_foo, make_foo); -} \ No newline at end of file + return read_or_build(&__bvalue_foo, make_foo); +} diff --git a/c_runtime/bosatsu_runtime.h b/c_runtime/bosatsu_runtime.h index efb486d1f..59bde9fdd 100644 --- a/c_runtime/bosatsu_runtime.h +++ b/c_runtime/bosatsu_runtime.h @@ -4,10 +4,77 @@ #include #include +/* +There are a few kinds of values: + +1. pure values: small ints, characters, small strings that can fit into 63 bits. +2. pointers to referenced counted values +3. pointers to static values stack allocated at startup + +to distinguish these cases we allocate pointers such that they are aligned to at least 4 byte +boundaries: + a. ends with 01: pure value + b. ends with 11: static pointer (allocated once and deleteds at the end of the world) + c. ends with 00: refcount pointer. + +when it comes to functions there are three types: + a. top level pure function: ends with 1 + b. static closure (something that closes over static things, ideally we would optimize this away): ends with 10 + c. refcounted closure: ends with 00 + +Nat-like values are represented by positive integers encoded as PURE_VALUE such that +NAT(x) = (x << 1) | 1, since we don't have enough time to increment through 2^{63} values +this is a safe encoding. + +Char values are stored as unicode code points with a trailing 1. + +String values encodings, string values are like ref-counted structs with +a length and char* holding the utf-8 bytes. We could also potentially optimize +short strings by packing them literally into 63 bits with a length. + +Integer values are either pure values (signed values packed into 63 bits), +or ref-counted big integers + +We need to know which case we are in because in generic context we need to know +how to clone values. +*/ +#define TAG_MASK 0x3 +#define PURE_VALUE_TAG 0x1 +#define STATIC_VALUE_TAG 0x3 +#define POINTER_TAG 0x0 + +// Utility macros to check the tag of a value +#define IS_PURE_VALUE(ptr) (((uintptr_t)(ptr) & PURE_VALUE_TAG) == PURE_VALUE_TAG) +#define PURE_VALUE(ptr) ((uintptr_t)(ptr) >> 1) +#define IS_STATIC_VALUE(ptr) (((uintptr_t)(ptr) & TAG_MASK) == STATIC_VALUE_TAG) +#define IS_POINTER(ptr) (((uintptr_t)(ptr) & TAG_MASK) == POINTER_TAG) +#define TO_POINTER(ptr) ((uintptr_t)(ptr) & ~TAG_MASK) +#define STATIC_PUREFN(ptr) (BValue*)((uintptr_t)(ptr) | PURE_VALUE_TAG) + +#define DEFINE_RC_STRUCT(name, fields) \ + struct name { \ + atomic_int ref_count; \ + FreeFn free; \ + fields \ + }; \ + typedef struct name name + typedef void* BValue; typedef uint32_t ENUM_TAG; #include "bosatsu_decls_generated.h" +// Nat values are encoded in integers +#define BSTS_NAT_0 ((BValue)0x1) +#define BSTS_NAT_SUCC(n) ((BValue)((uintptr_t)(n) + 2)) +#define BSTS_NAT_PREV(n) ((BValue)((uintptr_t)(n) - 2)) +#define BSTS_NAT_IS_0(n) (((uintptr_t)(n)) == 0x1) +#define BSTS_NAT_GT_0(n) (((uintptr_t)(n)) != 0x1) + +#define BSTS_AND(x, y) ((x) && (y)) + +#define BSTS_TO_CHAR(x) (BValue)((x << 1) | 1) +#define BSTS_NULL_TERM_STATIC_STR(x) (BValue)(((uintptr_t)(x)) | PURE_VALUE_TAG) + // this is the free function to call on an external value typedef void (*FreeFn)(void*); // A function which constructs a BValue @@ -21,6 +88,15 @@ BValue get_struct_index(BValue v, int idx); ENUM_TAG get_variant(BValue v); BValue get_enum_index(BValue v, int idx); +// This one is not auto generated because it can always be fit into the BValue directly +BValue alloc_enum0(ENUM_TAG tag); + +BValue bsts_string_from_utf8_bytes_copy(size_t len, char* bytes); +_Bool bsts_equals_string(BValue left, BValue right); + +BValue bsts_integer_from_int(int small_int); +BValue bsts_integer_from_words_copy(_Bool is_pos, size_t size, int32_t* words); +_Bool bsts_equals_int(BValue left, BValue right); BValue alloc_external(void* eval, FreeFn free_fn); void* get_external(BValue v); @@ -36,6 +112,8 @@ void free_statics(); BValue make_static(BValue v); void free_on_close(BValue v); +BValue read_or_build(_Atomic BValue* v, BConstruct cons); + #define CONSTRUCT(target, cons) (\ {\ BValue result = atomic_load(target);\ diff --git a/c_runtime/typegen.py b/c_runtime/typegen.py index cdca044ce..b3f56ab52 100644 --- a/c_runtime/typegen.py +++ b/c_runtime/typegen.py @@ -75,7 +75,7 @@ def function_impl(size): }} BValue call_fn{size}(BValue fn, {arg_params}) {{ - if (IS_STATIC_VALUE(fn)) {{ + if (IS_PURE_VALUE(fn)) {{ BPureFn{size} pfn = (BPureFn{size})TO_POINTER(fn); return pfn({just_args}); }} @@ -85,10 +85,6 @@ def function_impl(size): BValue* data = closure_data_of({cast_to_1}rc); return rc->fn(data, {just_args}); }} -}} - -BValue value_from_pure_fn{size}(BPureFn{size} fn) {{ - return (BValue)(((uintptr_t)fn) | STATIC_VALUE_TAG); }}""" cast_to_1 = "" if size == 1 else "(Closure1Data*)" arg_params = ", ".join("BValue arg{i}".format(i = i) for i in range(size)) diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 11f704767..dac23c593 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -17,6 +17,11 @@ object Matchless { def captures: List[Expr] // this is set if the function is recursive def recursiveName: Option[Bindable] + def recursionKind: RecursionKind = RecursionKind.recursive(recursiveName.isDefined) + + def args: NonEmptyList[Bindable] + def arity: Int = args.length + def body: Expr } sealed abstract class StrPart @@ -84,7 +89,7 @@ object Matchless { captures: List[Expr], recursiveName: Option[Bindable], args: NonEmptyList[Bindable], - expr: Expr + body: Expr ) extends FnExpr // this is a tail recursive function that should be compiled into a loop @@ -94,7 +99,7 @@ object Matchless { case class LoopFn( captures: List[Expr], name: Bindable, - arg: NonEmptyList[Bindable], + args: NonEmptyList[Bindable], body: Expr ) extends FnExpr { val recursiveName: Option[Bindable] = Some(name) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/Idents.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/Idents.scala index 491b7d1d2..1bc325629 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/Idents.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/Idents.scala @@ -2,8 +2,29 @@ package org.bykn.bosatsu.codegen object Idents { - private[this] val base62Items = - (('0' to '9') ++ ('A' to 'Z') ++ ('a' to 'z')).toSet + private[this] val firstChars = + (('a' to 'z') ++ ('A' to 'Z')).toArray + + private[this] val base62ItemsArray = + (firstChars ++ ('0' to '9')).toArray + + private[this] val base62Items = base62ItemsArray.toSet + + // these are all the strings that escape as themselves + val allSimpleIdents: LazyList[String] = { + val front = firstChars.to(LazyList).map(_.toString) + val inners = base62ItemsArray.to(LazyList).map(_.toString) + + lazy val tails: LazyList[String] = inners #::: (for { + h <- inners + t <- tails + } yield h + t) + + front #::: (for { + f <- front + t <- tails + } yield f + t) + } private[this] val offset0: Int = '0'.toInt private[this] val offsetA: Int = 'A'.toInt - 10 diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala new file mode 100644 index 000000000..933d323c1 --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -0,0 +1,814 @@ +package org.bykn.bosatsu.codegen.clang + +import cats.{Eval, Monad, Traverse} +import cats.data.{StateT, EitherT, NonEmptyList, Chain} +import java.math.BigInteger +import java.nio.charset.StandardCharsets +import org.bykn.bosatsu.codegen.Idents +import org.bykn.bosatsu.rankn.DataRepr +import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName} +import org.bykn.bosatsu.Matchless.Expr +import org.bykn.bosatsu.Identifier.Bindable +import org.typelevel.paiges.Doc + +import cats.syntax.all._ + +object ClangGen { + sealed abstract class Error + object Error { + case class UnknownValue(pack: PackageName, value: Bindable) extends Error + case class InvariantViolation(message: String, expr: Expr) extends Error + case class Unbound(bn: Bindable, inside: Option[(PackageName, Bindable)]) extends Error + } + + def renderMain( + sortedEnv: Vector[NonEmptyList[(PackageName, List[(Bindable, Expr)])]], + externals: Map[(PackageName, Bindable), (Code.Include, Code.Ident)], + value: (PackageName, Bindable), + evaluator: (Code.Include, Code.Ident) + ): Either[Error, Doc] = { + val env = Impl.Env.impl + import env._ + + val trav2 = Traverse[Vector].compose[NonEmptyList] + + val res = + trav2.traverse_(sortedEnv) { case (pn, values) => + values.traverse_ { case (bindable, expr) => + renderTop(pn, bindable, expr) + } + } *> env.renderMain(value._1, value._2, evaluator._1, evaluator._2) + + val allValues: Impl.AllValues = + sortedEnv + .iterator.flatMap(_.iterator) + .flatMap { case (p, vs) => + vs.iterator.map { case (b, e) => + (p, b) -> (e, Impl.generatedName(p, b)) + } + } + .toMap + + run(allValues, externals, res) + } + + private object Impl { + type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)] + type Externals = Map[(PackageName, Bindable), (Code.Include, Code.Ident)] + + def fullName(p: PackageName, b: Bindable): String = + p.asString + "/" + b.asString + + def generatedName(p: PackageName, b: Bindable): Code.Ident = + Code.Ident(Idents.escape("___bsts_g_", fullName(p, b))) + + trait Env { + import Matchless._ + + type T[A] + implicit val monadImpl: Monad[T] + def run(pm: AllValues, externals: Externals, t: T[Unit]): Either[Error, Doc] + def appendStatement(stmt: Code.Statement): T[Unit] + def error[A](e: => Error): T[A] + def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] + def bind[A](bn: Bindable)(in: T[A]): T[A] + def getBinding(bn: Bindable): T[Code.Ident] + def bindAnon[A](idx: Long)(in: T[A]): T[A] + def getAnon(idx: Long): T[Code.Ident] + // a recursive function needs to remap the Bindable to the top-level mangling + def recursiveName[A](fnName: Code.Ident, bn: Bindable, isClosure: Boolean)(in: T[A]): T[A] + // used for temporary variables of type BValue + def newLocalName(tag: String): T[Code.Ident] + def newTopName(tag: String): T[Code.Ident] + def directFn(p: PackageName, b: Bindable): T[Option[Code.Ident]] + def directFn(b: Bindable): T[Option[(Code.Ident, Boolean)]] + def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A] + def staticValueName(p: PackageName, b: Bindable): T[Code.Ident] + def constructorFn(p: PackageName, b: Bindable): T[Code.Ident] + + ///////////////////////////////////////// + // the below are independent of the environment implementation + ///////////////////////////////////////// + + // This name has to be impossible to give out for any other purpose + val slotsArgName: Code.Ident = Code.Ident("__bstsi_slot") + + // assign any results to result and set the condition to false + // and replace any tail calls to nm(args) with assigning args to those values + def toWhileBody(fnName: Code.Ident, args: NonEmptyList[Code.Param], isClosure: Boolean, cond: Code.Ident, result: Code.Ident, body: Code.ValueLike): Code.Block = { + + import Code._ + + def returnValue(vl: ValueLike): Statement = + (cond := FalseLit) + + (result := vl) + + def loop(vl: ValueLike): Option[Statement] = + vl match { + case Apply(fn, appArgs) if fn == fnName => + // this is a tail call + val newArgsList = + if (isClosure) appArgs.tail + else appArgs + // we know the length of appArgs must match args or the code wouldn't have compiled + val assigns = args.zipWith(NonEmptyList.fromListUnsafe(newArgsList)) { + case (Param(_, name), value) => + Assignment(name, value) + } + Some(Statements(assigns)) + case IfElseValue(c, t, f) => + // this can possible have tail calls inside the branches + (loop(t), loop(f)) match { + case (Some(t), Some(f)) => + Some(ifThenElse(c, t, f)) + case (None, Some(f)) => + Some(ifThenElse(c, returnValue(t), f)) + case (Some(t), None) => + Some(ifThenElse(c, t, returnValue(f))) + case (None, None) => None + } + case Ternary(c, t, f) => loop(IfElseValue(c, t, f)) + case WithValue(s, vl) => loop(vl).map(s + _) + case Apply(_, _) | Cast(_, _) | BinExpr(_, _, _) | Bracket(_, _) | Ident(_) | + IntLiteral(_) | PostfixExpr(_, _) | PrefixExpr(_, _) | Select(_, _) | + StrLiteral(_) => None + } + + loop(body) match { + case Some(stmt) => block(stmt) + case None => + sys.error("invariant violation: could not find tail calls in:" + + s"toWhileBody(fnName = $fnName, body = $body)") + } + } + + def bindAll[A](nel: NonEmptyList[Bindable])(in: T[A]): T[A] = + bind(nel.head) { + NonEmptyList.fromList(nel.tail) match { + case None => in + case Some(rest) => bindAll(rest)(in) + } + } + + def equalsChar(expr: Code.Expression, codePoint: Int): Code.Expression = + expr =:= Code.Ident("BSTS_TO_CHAR")(Code.IntLiteral(codePoint)) + + def pv(e: Code.ValueLike): T[Code.ValueLike] = monadImpl.pure(e) + + // The type of this value must be a C _Bool + def boolToValue(boolExpr: BoolExpr): T[Code.ValueLike] = + boolExpr match { + case EqualsLit(expr, lit) => + innerToValue(expr).flatMap { vl => + lit match { + case c @ Lit.Chr(_) => vl.onExpr { e => pv(equalsChar(e, c.toCodePoint)) }(newLocalName) + case Lit.Str(_) => + vl.onExpr { e => + literal(lit).flatMap { litStr => + Code.ValueLike.applyArgs(Code.Ident("bsts_equals_string"), + NonEmptyList(e, litStr :: Nil) + )(newLocalName) + } + }(newLocalName) + case Lit.Integer(_) => + vl.onExpr { e => + literal(lit).flatMap { litStr => + Code.ValueLike.applyArgs(Code.Ident("bsts_equals_int"), + NonEmptyList(e, litStr :: Nil) + )(newLocalName) + } + }(newLocalName) + } + } + case EqualsNat(expr, nat) => + val fn = nat match { + case DataRepr.ZeroNat => Code.Ident("BSTS_NAT_IS_0") + case DataRepr.SuccNat => Code.Ident("BSTS_NAT_GT_0") + } + innerToValue(expr).flatMap { vl => + vl.onExpr { expr => pv(fn(expr)) }(newLocalName) + } + case And(e1, e2) => + (boolToValue(e1), boolToValue(e2)) + .flatMapN { (a, b) => + Code.ValueLike.applyArgs( + Code.Ident("BSTS_AND"), + NonEmptyList(a, b :: Nil) + )(newLocalName) + } + case CheckVariant(expr, expect, _, _) => + innerToValue(expr).flatMap { vl => + // this is just get_variant(expr) == expect + vl.onExpr { expr => pv(Code.Ident("get_variant")(expr) =:= Code.IntLiteral(expect)) }(newLocalName) + } + case sl @ SearchList(lst, init, check, leftAcc) => + // TODO: ??? + println(s"TODO: implement boolToValue($sl) returning false") + pv(Code.FalseLit) + case ms @ MatchString(arg, parts, binds) => + // TODO: ??? + println(s"TODO: implement boolToValue($ms) returning false") + pv(Code.FalseLit) + case SetMut(LocalAnonMut(idx), expr) => + for { + name <- getAnon(idx) + vl <- innerToValue(expr) + } yield (name := vl) +: Code.TrueLit + case TrueConst => pv(Code.TrueLit) + } + + // We have to lift functions to the top level and not + // create any nesting + def innerFn(fn: FnExpr): T[Code.ValueLike] = + if (fn.captures.isEmpty) { + for { + ident <- newTopName("lambda") + stmt <- fnStatement(ident, fn) + _ <- appendStatement(stmt) + } yield Code.Ident("STATIC_PUREFN")(ident) + } + else { + // we create the function, then we allocate + // values for the capture + // alloc_closure(capLen, captures, fnName) + for { + ident <- newTopName("closure") + stmt <- fnStatement(ident, fn) + _ <- appendStatement(stmt) + capName <- newLocalName("captures") + capValues <- fn.captures.traverse(innerToValue(_)) + decl <- Code.ValueLike.declareArray(capName, Code.TypeIdent.BValue, capValues)(newLocalName) + } yield Code.WithValue(decl, + Code.Ident(s"alloc_closure${fn.arity}")( + Code.IntLiteral(BigInt(fn.captures.length)), + capName, + ident + ) + ) + } + + def literal(lit: Lit): T[Code.ValueLike] = + lit match { + case c @ Lit.Chr(_) => + // encoded as integers in pure values + pv(Code.Ident("BSTS_TO_CHAR")(Code.IntLiteral(c.toCodePoint))) + case Lit.Integer(toBigInteger) => + try { + val iv = toBigInteger.intValueExact() + pv(Code.Ident("bsts_integer_from_int")(Code.IntLiteral(iv))) + } + catch { + case _: ArithmeticException => + // emit the uint32 words and sign + val isPos = toBigInteger.signum >= 0 + var current = if (isPos) toBigInteger else toBigInteger.negate() + val two32 = BigInteger.ONE.shiftLeft(32) + val bldr = List.newBuilder[Code.IntLiteral] + while (current.compareTo(BigInteger.ZERO) > 0) { + bldr += Code.IntLiteral(current.mod(two32).longValue()) + current = current.shiftRight(32) + } + val lits = bldr.result() + //call: + // bsts_integer_from_words_copy(_Bool is_pos, size_t size, int32_t* words); + newLocalName("int").map { ident => + Code.DeclareArray(Code.TypeIdent.UInt32, ident, Right(lits)) +: + Code.Ident("bsts_integer_from_words_copy")( + if (isPos) Code.TrueLit else Code.FalseLit, + Code.IntLiteral(lits.length), + ident + ) + } + } + + case Lit.Str(toStr) => + // convert to utf8 and then to a literal array of bytes + val bytes = toStr.getBytes(StandardCharsets.UTF_8) + if (bytes.forall(_.toInt != 0)) { + // just send the utf8 bytes as a string to C + pv( + Code.Ident("BSTS_NULL_TERM_STATIC_STR")(Code.StrLiteral( + new String(bytes.map(_.toChar)) + )) + ) + } + else { + // We have some null bytes, we have to encode the length + val lits = + bytes.iterator.map { byte => + Code.IntLiteral(byte.toInt & 0xff) + }.toList + //call: + // bsts_string_from_utf8_bytes_copy(size_t size, char* bytes); + newLocalName("str").map { ident => + // TODO: this could be a static top level definition to initialize + // one time and avoid the copy probably, but copies are fast.... + Code.DeclareArray(Code.TypeIdent.Char, ident, Right(lits)) +: + Code.Ident("bsts_string_from_utf8_bytes_copy")( + Code.IntLiteral(lits.length), + ident + ) + } + } + } + + def innerApp(app: App): T[Code.ValueLike] = + app match { + case App(Global(pack, fnName), args) => + directFn(pack, fnName).flatMap { + case Some(ident) => + // directly invoke instead of by treating them like lambdas + args.traverse(innerToValue(_)).flatMap { argsVL => + Code.ValueLike.applyArgs(ident, argsVL)(newLocalName) + } + case None => + // the ref be holding the result of another function call + (globalIdent(pack, fnName), args.traverse(innerToValue(_))).flatMapN { (fnVL, argsVL) => + // we need to invoke call_fn(fn, arg0, arg1, ....) + // but since these are ValueLike, we need to handle more carefully + val fnSize = argsVL.length + val callFn = Code.Ident(s"call_fn$fnSize") + Code.ValueLike.applyArgs(callFn, fnVL :: argsVL)(newLocalName) + } + } + case App(Local(fnName), args) => + directFn(fnName).flatMap { + case Some((ident, isClosure)) => + // directly invoke instead of by treating them like lambdas + args.traverse(innerToValue(_)).flatMap { argsVL => + val withSlot = + if (isClosure) slotsArgName :: argsVL + else argsVL + Code.ValueLike.applyArgs(ident, withSlot)(newLocalName) + } + case None => + // the ref be holding the result of another function call + (getBinding(fnName), args.traverse(innerToValue(_))).flatMapN { (fnVL, argsVL) => + // we need to invoke call_fn(fn, arg0, arg1, ....) + // but since these are ValueLike, we need to handle more carefully + val fnSize = argsVL.length + val callFn = Code.Ident(s"call_fn$fnSize") + Code.ValueLike.applyArgs(callFn, fnVL :: argsVL)(newLocalName) + } + } + case App(MakeEnum(variant, arity, _), args) => + // to type check, we know that the arity must have the same length as args + args.traverse(innerToValue).flatMap { argsVL => + val tag = Code.IntLiteral(variant) + Code.ValueLike.applyArgs(Code.Ident(s"alloc_enum$arity"), tag :: argsVL)(newLocalName) + } + case App(MakeStruct(arity), args) => + if (arity == 1) { + // this is a new-type, just return the arg + innerToValue(args.head) + } + else { + // to type check, we know that the arity must have the same length as args + args.traverse(innerToValue).flatMap { argsVL => + Code.ValueLike.applyArgs(Code.Ident(s"alloc_struct$arity"), argsVL)(newLocalName) + } + } + case App(SuccNat, args) => + innerToValue(args.head).flatMap { arg => + Code.ValueLike.applyArgs(Code.Ident("BSTS_NAT_SUCC"), NonEmptyList.one(arg))(newLocalName) + } + case App(fn, args) => + (innerToValue(fn), args.traverse(innerToValue(_))).flatMapN { (fnVL, argsVL) => + // we need to invoke call_fn(fn, arg0, arg1, ....) + // but since these are ValueLike, we need to handle more carefully + val fnSize = argsVL.length + val callFn = Code.Ident(s"call_fn$fnSize") + Code.ValueLike.applyArgs(callFn, fnVL :: argsVL)(newLocalName) + } + } + + def innerToValue(expr: Expr): T[Code.ValueLike] = + expr match { + case fn: FnExpr => innerFn(fn) + case Let(Right(arg), argV, in) => + bind(arg) { + for { + name <- getBinding(arg) + v <- innerToValue(argV) + result <- innerToValue(in) + stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) + } yield stmt +: result + } + case Let(Left(LocalAnon(idx)), argV, in) => + bindAnon(idx) { + for { + name <- getAnon(idx) + v <- innerToValue(argV) + result <- innerToValue(in) + stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) + } yield stmt +: result + } + case app @ App(_, _) => innerApp(app) + case Global(pack, name) => + directFn(pack, name) + .flatMap { + case Some(nm) => + pv(Code.Ident("STATIC_PUREFN")(nm)) + case None => + // read_or_build(&__bvalue_foo, make_foo); + for { + value <- staticValueName(pack, name) + consFn <- constructorFn(pack, name) + } yield Code.Ident("read_or_build")(value.addr, consFn): Code.ValueLike + } + case Local(arg) => + directFn(arg) + .flatMap { + case Some((nm, false)) => + // a closure can't be a static name + pv(Code.Ident("STATIC_PUREFN")(nm)) + case _ => + getBinding(arg).widen + } + case ClosureSlot(idx) => + // we must be inside a closure function, so we should have a slots argument to access + pv(slotsArgName.bracket(Code.IntLiteral(BigInt(idx)))) + case LocalAnon(ident) => getAnon(ident).widen + case LocalAnonMut(ident) => getAnon(ident).widen + case LetMut(LocalAnonMut(m), span) => + bindAnon(m) { + for { + ident <- getAnon(m) + decl = Code.DeclareVar(Nil, Code.TypeIdent.BValue, ident, None) + res <- innerToValue(span) + } yield decl +: res + } + case Literal(lit) => literal(lit) + case If(cond, thenExpr, elseExpr) => + (boolToValue(cond), innerToValue(thenExpr), innerToValue(elseExpr)) + .flatMapN { (c, thenC, elseC) => + Code.ValueLike.ifThenElseV(c, thenC, elseC)(newLocalName) + } + case Always(cond, thenExpr) => + boolToValue(cond).flatMap { bv => + bv.discardValue match { + case None => innerToValue(thenExpr) + case Some(effect) => innerToValue(thenExpr).map(effect +: _) + } + } + case GetEnumElement(arg, _, index, _) => + // call get_enum_index(v, index) + innerToValue(arg).flatMap { v => + v.onExpr(e => pv(Code.Ident("get_enum_index")(e, Code.IntLiteral(index))))(newLocalName) + } + case GetStructElement(arg, index, size) => + if (size == 1) { + // this is just a new-type wrapper, ignore it + innerToValue(arg) + } + else { + // call get_struct_index(v, index) + innerToValue(arg).flatMap { v => + v.onExpr { e => + pv(Code.Ident("get_struct_index")(e, Code.IntLiteral(index))) + }(newLocalName) + } + } + case makeEnum @ MakeEnum(variant, arity, _) => + // this is a closure over variant, we rewrite this + if (arity == 0) pv(Code.Ident("alloc_enum0")(Code.IntLiteral(variant))) + else { + val named = + // safe because arity > 0 + NonEmptyList.fromListUnsafe( + Idents.allSimpleIdents.take(arity).map { nm => Identifier.Name(nm) }.toList + ) + // This relies on optimizing App(MakeEnum, _) otherwise + // it creates an infinite loop. + // Also, this we should cache creation of Lambda/Closure values + innerToValue(Lambda(Nil, None, named, App(makeEnum, named.map(Local(_))))) + } + case MakeStruct(arity) => + pv { + if (arity == 0) Code.Ident("PURE_VALUE_TAG") + else { + val allocStructFn = s"alloc_struct$arity" + Code.Ident("STATIC_PUREFN")(Code.Ident(allocStructFn)) + } + } + case ZeroNat => + pv(Code.Ident("BSTS_NAT_0")) + case SuccNat => + val arg = Identifier.Name("arg0") + // This relies on optimizing App(SuccNat, _) otherwise + // it creates an infinite loop. + // Also, this we should cache creation of Lambda/Closure values + innerToValue(Lambda(Nil, None, NonEmptyList.one(arg), + App(SuccNat, NonEmptyList.one(Local(arg))))) + case PrevNat(of) => + innerToValue(of).flatMap { argVL => + Code.ValueLike.applyArgs( + Code.Ident("BSTS_NAT_PREV"), + NonEmptyList.one(argVL) + )(newLocalName) + } + } + + def fnStatement(fnName: Code.Ident, fn: FnExpr): T[Code.Statement] = + fn match { + case Lambda(captures, name, args, expr) => + val body = innerToValue(expr).map(Code.returnValue(_)) + val body1 = name match { + case None => body + case Some(rec) => recursiveName(fnName, rec, isClosure = captures.nonEmpty)(body) + } + + bindAll(args) { + for { + argParams <- args.traverse { b => + getBinding(b).map { i => Code.Param(Code.TypeIdent.BValue, i) } + } + fnBody <- body1 + allArgs = + if (captures.isEmpty) argParams + else { + Code.Param(Code.TypeIdent.BValue.ptr, slotsArgName) :: argParams + } + } yield Code.DeclareFn(Nil, Code.TypeIdent.BValue, fnName, allArgs.toList, Some(Code.block(fnBody))) + } + case LoopFn(captures, nm, args, body) => + recursiveName(fnName, nm, isClosure = captures.nonEmpty) { + bindAll(args) { + for { + cond <- newLocalName("cond") + res <- newLocalName("res") + bodyVL <- innerToValue(body) + argParams <- args.traverse { b => + getBinding(b).map { i => Code.Param(Code.TypeIdent.BValue, i) } + } + whileBody = toWhileBody(fnName, argParams, isClosure = captures.nonEmpty, cond = cond, result = res, body = bodyVL) + fnBody = Code.block( + Code.DeclareVar(Nil, Code.TypeIdent.Bool, cond, Some(Code.TrueLit)), + Code.DeclareVar(Nil, Code.TypeIdent.BValue, res, None), + Code.While(cond, whileBody), + Code.Return(Some(res)) + ) + allArgs = + if (captures.isEmpty) argParams + else { + Code.Param(Code.TypeIdent.BValue.ptr, slotsArgName) :: argParams + } + } yield Code.DeclareFn(Nil, Code.TypeIdent.BValue, fnName, allArgs.toList, Some(fnBody)) + } + } + } + + def renderTop(p: PackageName, b: Bindable, expr: Expr): T[Unit] = + inTop(p, b) { expr match { + case fn: FnExpr => + for { + fnName <- globalIdent(p, b) + stmt <- fnStatement(fnName, fn) + _ <- appendStatement(stmt) + } yield () + case someValue => + // we materialize an Atomic value to hold the static data + // then we generate a function to populate the value + for { + vl <- innerToValue(someValue) + value <- staticValueName(p, b) + consFn <- constructorFn(p, b) + _ <- appendStatement(Code.DeclareVar( + Code.Attr.Static :: Nil, + Code.TypeIdent.AtomicBValue, + value, + Some(Code.IntLiteral.Zero) + )) + _ <- appendStatement(Code.DeclareFn( + Code.Attr.Static :: Nil, + Code.TypeIdent.BValue, + consFn, + Nil, + Some(Code.block(Code.returnValue(vl))) + )) + } yield () + } + } + + def renderMain(p: PackageName, b: Bindable, evalInc: Code.Include, evalFn: Code.Ident): T[Unit] + } + + object Env { + def impl: Env = { + def catsMonad[S]: Monad[StateT[EitherT[Eval, Error, *], S, *]] = implicitly + + new Env { + case class State( + allValues: AllValues, + externals: Externals, + includeSet: Set[Code.Include], + includes: Chain[Code.Include], + stmts: Chain[Code.Statement], + currentTop: Option[(PackageName, Bindable)], + binds: Map[Bindable, NonEmptyList[Either[((Code.Ident, Boolean), Int), Int]]], + counter: Long + ) { + def finalFile: Doc = + Doc.intercalate(Doc.hardLine, includes.iterator.map(Code.toDoc(_)).toList) + + Doc.hardLine + Doc.hardLine + + Doc.intercalate(Doc.hardLine + Doc.hardLine, stmts.iterator.map(Code.toDoc(_)).toList) + } + + object State { + def init(allValues: AllValues, externals: Externals): State = { + val defaultIncludes = + List(Code.Include(true, "bosatsu_runtime.h")) + + State(allValues, externals, Set.empty ++ defaultIncludes, Chain.fromSeq(defaultIncludes), Chain.empty, + None, Map.empty, 0L + ) + } + } + + type T[A] = StateT[EitherT[Eval, Error, *], State, A] + + implicit val monadImpl: Monad[T] = catsMonad[State] + + def run(pm: AllValues, externals: Externals, t: T[Unit]): Either[Error, Doc] = + t.run(State.init(pm, externals)) + .value // get the value out of the EitherT + .value // evaluate the Eval + .map(_._1.finalFile) + + def appendStatement(stmt: Code.Statement): T[Unit] = + StateT.modify(s => s.copy(stmts = s.stmts :+ stmt)) + + def errorRes[A](e: => Error): EitherT[Eval, Error, A] = + EitherT[Eval, Error, A](Eval.later(Left(e))) + + def error[A](e: => Error): T[A] = + StateT(_ => errorRes(e)) + + def result[A](s: State, a: A): EitherT[Eval, Error, (State, A)] = + EitherT[Eval, Error, (State, A)]( + Eval.now(Right((s, a))) + ) + + def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] = + StateT { s => + val key = (pn, bn) + s.externals.get(key) match { + case Some((incl, ident)) => + val withIncl = + if (s.includeSet(incl)) s + else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl) + + result(withIncl, ident) + case None => + s.allValues.get(key) match { + case Some((_, ident)) => result(s, ident) + case None => errorRes(Error.UnknownValue(pn, bn)) + } + } + } + + def bind[A](bn: Bindable)(in: T[A]): T[A] = { + val init: T[Unit] = StateT { s => + val v = s.binds.get(bn) match { + case None => NonEmptyList.one(Right(0)) + case Some(items @ NonEmptyList(Right(idx), _)) => + Right(idx + 1) :: items + case Some(items @ NonEmptyList(Left((_, idx)), _)) => + Right(idx + 1) :: items + } + result(s.copy(binds = s.binds.updated(bn, v)), ()) + } + + val uninit: T[Unit] = StateT { s => + s.binds.get(bn) match { + case Some(NonEmptyList(_, tail)) => + val s1 = NonEmptyList.fromList(tail) match { + case None => + s.copy(binds = s.binds - bn) + case Some(prior) => + s.copy(binds = s.binds.updated(bn, prior)) + } + result(s1, ()) + case None => sys.error(s"bindable $bn no longer in $s") + } + } + + for { + _ <- init + a <- in + _ <- uninit + } yield a + } + def getBinding(bn: Bindable): T[Code.Ident] = + StateT { s => + s.binds.get(bn) match { + case Some(stack) => + stack.head match { + case Right(idx) => + result(s, Code.Ident(Idents.escape("__bsts_b_", bn.asString + idx.toString))) + case Left(((ident, _), _)) => + // TODO: suspicious to ignore isClosure here + result(s, ident) + } + case None => errorRes(Error.Unbound(bn, s.currentTop)) + } + } + def bindAnon[A](idx: Long)(in: T[A]): T[A] = + // in the future we see the scope of the binding which matters for GC, but here + // we don't care + in + + def getAnon(idx: Long): T[Code.Ident] = + monadImpl.pure(Code.Ident(Idents.escape("__bsts_a_", idx.toString))) + + // a recursive function needs to remap the Bindable to the top-level mangling + def recursiveName[A](fnName: Code.Ident, bn: Bindable, isClosure: Boolean)(in: T[A]): T[A] = { + val init: T[Unit] = StateT { s => + val entry = (fnName, isClosure) + val v = s.binds.get(bn) match { + case None => NonEmptyList.one(Left((entry, -1))) + case Some(items @ NonEmptyList(Right(idx), _)) => + Left((entry, idx)) :: items + case Some(items @ NonEmptyList(Left((_, idx)), _)) => + Left((entry, idx)) :: items + } + result(s.copy(binds = s.binds.updated(bn, v)), ()) + } + + val uninit: T[Unit] = StateT { s => + s.binds.get(bn) match { + case Some(NonEmptyList(_, tail)) => + val s1 = NonEmptyList.fromList(tail) match { + case None => + s.copy(binds = s.binds - bn) + case Some(prior) => + s.copy(binds = s.binds.updated(bn, prior)) + } + result(s1, ()) + case None => sys.error(s"bindable $bn no longer in $s") + } + } + + for { + _ <- init + a <- in + _ <- uninit + } yield a + } + + val nextCnt: T[Long] = + StateT { s => + val cnt = s.counter + val s1 = s.copy(counter = cnt + 1L) + result(s1, cnt) + } + + // used for temporary variables of type BValue + def newLocalName(tag: String): T[Code.Ident] = + nextCnt.map { cnt => + Code.Ident(Idents.escape("__bsts_l_", tag + cnt.toString)) + } + def newTopName(tag: String): T[Code.Ident] = + nextCnt.map { cnt => + Code.Ident(Idents.escape("__bsts_t_", tag + cnt.toString)) + } + // record that this name is a top level function, so applying it can be direct + def directFn(pack: PackageName, b: Bindable): T[Option[Code.Ident]] = + StateT { s => + s.allValues.get((pack, b)) match { + case Some((_: Matchless.FnExpr, ident)) => + result(s, Some(ident)) + case _ => result(s, None) + } + } + + def directFn(b: Bindable): T[Option[(Code.Ident, Boolean)]] = + StateT { s => + s.binds.get(b) match { + case Some(NonEmptyList(Left((c, _)), _)) => + result(s, Some(c)) + case _ => + result(s, None) + } + } + + def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A] = + for { + _ <- StateT { (s: State) => result(s.copy(currentTop = Some((p, bn))), ())} + a <- ta + _ <- StateT { (s: State) => result(s.copy(currentTop = None), ()) } + } yield a + + def staticValueName(p: PackageName, b: Bindable): T[Code.Ident] = + monadImpl.pure(Code.Ident(Idents.escape("___bsts_s_", fullName(p, b)))) + def constructorFn(p: PackageName, b: Bindable): T[Code.Ident] = + monadImpl.pure(Code.Ident(Idents.escape("___bsts_c_", fullName(p, b)))) + + def renderMain(p: PackageName, b: Bindable, evalInc: Code.Include, evalFn: Code.Ident): T[Unit] = + // TODO ??? + monadImpl.unit + } + } + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala index 28ad90da7..565be64b4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala @@ -1,9 +1,13 @@ package org.bykn.bosatsu.codegen.clang +import cats.Monad +import cats.data.{NonEmptyList, NonEmptyChain} +import java.nio.charset.StandardCharsets import org.typelevel.paiges.Doc -import cats.data.NonEmptyList import scala.language.implicitConversions +import cats.syntax.all._ + sealed trait Code object Code { @@ -34,6 +38,12 @@ object Code { case class UnionType(name: String) extends ComplexType case class Named(name: String) extends TypeIdent case class Ptr(tpe: TypeIdent) extends TypeIdent + val Int: TypeIdent = Named("int") + val UInt32: TypeIdent = Named("uint32_t") + val Char: TypeIdent = Named("char") + val BValue: TypeIdent = Named("BValue") + val AtomicBValue: TypeIdent = Named("_Atomic BValue") + val Bool: TypeIdent = Named("_Bool") private val structDoc = Doc.text("struct ") private val unionDoc = Doc.text("union ") @@ -47,9 +57,67 @@ object Code { } } - sealed trait Expression extends Code { - def :=(rhs: Expression): Statement = - Assignment(this, rhs) + sealed trait ValueLike { + def +:(prefix: Statement): ValueLike = + ValueLike.prefix(prefix, this) + + def discardValue: Option[Statement] = + this match { + case _: Expression => None + case WithValue(stmt, vl) => + vl.discardValue match { + case None => Some(stmt) + case Some(rhs) => Some(stmt + rhs) + } + case IfElseValue(cond, thenC, elseC) => + (thenC.discardValue, elseC.discardValue) match { + case (Some(ts), Some(es)) => Some(ifThenElse(cond, ts, es)) + case (Some(ts), None) => Some(IfElse(NonEmptyList.one(cond -> block(ts)), None)) + case (None, Some(es)) => + // if (cond) {} else {es} == if (!cond) { es } + Some(IfElse(NonEmptyList.one(!cond -> block(es)), None)) + case (None, None) => None + } + } + + def onExpr[F[_]: Monad](fn: Expression => F[ValueLike])(newLocalName: String => F[Code.Ident]): F[ValueLike] = + this match { + case expr: Expression => fn(expr) + case WithValue(stmt, vl) => vl.onExpr[F](fn)(newLocalName).map(stmt +: _) + case branch @ IfElseValue(_, _, _) => + for { + resIdent <- newLocalName("branch_res") + value <- fn(resIdent) + } yield ( + // Assign branchCond to a temp variable in both branches + // and then use it so we don't exponentially blow up the code + // size + (Code.DeclareVar(Nil, Code.TypeIdent.BValue, resIdent, None) + + (resIdent := branch)) +: value + ) + } + + def exprToStatement[F[_]: Monad](fn: Expression => F[Statement])(newLocalName: String => F[Code.Ident]): F[Statement] = + this match { + case expr: Expression => fn(expr) + case WithValue(stmt, vl) => vl.exprToStatement[F](fn)(newLocalName).map(stmt + _) + case branch @ IfElseValue(_, _, _) => + for { + resIdent <- newLocalName("branch_res") + last <- fn(resIdent) + } yield + // Assign branchCond to a temp variable in both branches + // and then use it so we don't exponentially blow up the code + // size + (Code.DeclareVar(Nil, Code.TypeIdent.BValue, resIdent, None) + + (resIdent := branch)) + + last + } + } + + sealed trait Expression extends Code with ValueLike { + def :=(rhs: ValueLike): Statement = + ValueLike.assign(this, rhs) def ret: Statement = Return(Some(this)) @@ -72,11 +140,112 @@ object Code { def -(that: Expression): Expression = bin(BinOp.Sub, that) def *(that: Expression): Expression = bin(BinOp.Mult, that) def /(that: Expression): Expression = bin(BinOp.Div, that) + def unary_! : Expression = PrefixExpr(PrefixUnary.Not, this) + def =:=(that: Expression): Expression = bin(BinOp.Eq, that) def postInc: Expression = PostfixExpr(this, PostfixUnary.Inc) def postDec: Expression = PostfixExpr(this, PostfixUnary.Dec) } + ///////////////////////// + // Here are all the ValueLike + ///////////////////////// + + // this prepares an expression with a number of statements + case class WithValue(statement: Statement, value: ValueLike) extends ValueLike + // At least one of thenCond or elseCond should not be an expression + case class IfElseValue(cond: Expression, thenCond: ValueLike, elseCond: ValueLike) extends ValueLike + + object ValueLike { + def applyArgs[F[_]: Monad]( + fn: ValueLike, + args: NonEmptyList[ValueLike] + )(newLocalName: String => F[Code.Ident]): F[ValueLike] = + fn.onExpr { fnExpr => + def loop(rest: List[ValueLike], acc: NonEmptyList[Expression]): F[ValueLike] = + rest match { + case Nil => Monad[F].pure[ValueLike](fnExpr(acc.reverse.toList: _*)) + case h :: t => h.onExpr { hexpr => loop(t, hexpr :: acc) }(newLocalName) + } + + args.head.onExpr { hexpr => + loop(args.tail, NonEmptyList.one(hexpr)) + }(newLocalName) + }(newLocalName) + + + def declareArray[F[_]: Monad](ident: Ident, tpe: TypeIdent, values: List[ValueLike])(newLocalName: String => F[Code.Ident]): F[Statement] = { + def loop(values: List[ValueLike], acc: List[Expression]): F[Statement] = + values match { + case Nil => Monad[F].pure(DeclareArray(tpe, ident, Right(acc.reverse))) + case (e: Expression) :: tail => + loop(tail, e :: acc) + case h :: tail => + h.exprToStatement { e => + loop(tail, e :: acc) + }(newLocalName) + } + + loop(values, Nil) + } + + def declareVar[F[_]: Monad]( + ident: Ident, + tpe: TypeIdent, + value: ValueLike)(newLocalName: String => F[Code.Ident]): F[Statement] = + value.exprToStatement[F] { expr => + Monad[F].pure(DeclareVar(Nil, tpe, ident, Some(expr))) + }(newLocalName) + + def prefix(stmt: Statement, of: ValueLike): ValueLike = + of match { + case (_: Expression) | (_: IfElseValue) => WithValue(stmt, of) + case WithValue(stmt1, v) => WithValue(stmt + stmt1, v) + } + + def assign(left: Expression, rhs: ValueLike): Statement = + rhs match { + case expr: Expression => Assignment(left, expr) + case WithValue(stmt, v) => stmt + (left := v) + case IfElseValue(cond, thenC, elseC) => + ifThenElse(cond, left := thenC, left := elseC) + } + + def ifThenElseV[F[_]: Monad](cond: Code.ValueLike, thenC: Code.ValueLike, elseC: Code.ValueLike)(newLocalName: String => F[Code.Ident]): F[Code.ValueLike] = { + cond match { + case expr: Code.Expression => + Monad[F].pure { + (thenC, elseC) match { + case (thenX: Expression, elseX: Expression) => Ternary(expr, thenX, elseX) + case _ => IfElseValue(expr, thenC, elseC) + } + } + case Code.WithValue(stmt, v) => + ifThenElseV(v, thenC, elseC)(newLocalName).map(stmt +: _) + case branchCond @ Code.IfElseValue(_, _, _) => + for { + condIdent <- newLocalName("cond") + res <- ifThenElseV(condIdent, thenC, elseC)(newLocalName) + } yield { + // Assign branchCond to a temp variable in both branches + // and then use it so we don't exponentially blow up the code + // size + (Code.DeclareVar(Nil, Code.TypeIdent.Bool, condIdent, None) + + (condIdent := branchCond)) +: + res + } + } + } + } + + def returnValue(vl: ValueLike): Statement = + vl match { + case expr: Expression => Return(Some(expr)) + case WithValue(stmt, v) => stmt + returnValue(v) + case IfElseValue(cond, thenC, elseC) => + ifThenElse(cond, returnValue(thenC), returnValue(elseC)) + } + sealed abstract class BinOp(repr: String) { val toDoc: Doc = Doc.text(repr) } @@ -131,6 +300,18 @@ object Code { implicit def fromString(str: String): Ident = Ident(str) } case class IntLiteral(value: BigInt) extends Expression + case class StrLiteral(value: String) extends Expression + object IntLiteral { + val One: IntLiteral = IntLiteral(BigInt(1)) + val Zero: IntLiteral = IntLiteral(BigInt(0)) + + def apply(i: Int): IntLiteral = IntLiteral(BigInt(i)) + def apply(i: Long): IntLiteral = IntLiteral(BigInt(i)) + } + + def TrueLit: Expression = IntLiteral.One + def FalseLit: Expression = IntLiteral.Zero + case class Cast(tpe: TypeIdent, expr: Expression) extends Expression case class Apply(fn: Expression, args: List[Expression]) extends Expression case class Select(target: Expression, name: Ident) extends Expression @@ -144,7 +325,10 @@ object Code { def toDoc: Doc = TypeIdent.toDoc(tpe) + Doc.space + Doc.text(name.name) } - sealed trait Statement extends Code + sealed trait Statement extends Code { + def +(stmt: Statement): Statement = Statements.combine(this, stmt) + def :+(vl: ValueLike): ValueLike = (this +: vl) + } case class Assignment(target: Expression, value: Expression) extends Statement case class DeclareArray(tpe: TypeIdent, ident: Ident, values: Either[Int, List[Expression]]) extends Statement case class DeclareVar(attrs: List[Attr], tpe: TypeIdent, ident: Ident, value: Option[Expression]) extends Statement @@ -155,6 +339,26 @@ object Code { case class Block(items: NonEmptyList[Statement]) extends Statement { def doWhile(cond: Expression): Statement = DoWhile(this, cond) } + // nothing more than a collection of statements + case class Statements(items: NonEmptyChain[Statement]) extends Statement + object Statements { + def apply(nel: NonEmptyList[Statement]): Statements = + Statements(NonEmptyChain.fromNonEmptyList(nel)) + + def combine(first: Statement, last: Statement): Statement = + first match { + case Statements(items) => + last match { + case Statements(rhs) => Statements(items ++ rhs) + case notStmts => Statements(items :+ notStmts) + } + case notBlock => + last match { + case Statements(rhs) => Statements(notBlock +: rhs) + case notStmts => Statements(NonEmptyChain.of(notBlock, notStmts)) + } + } + } case class IfElse(ifs: NonEmptyList[(Expression, Block)], elseCond: Option[Block]) extends Statement case class DoWhile(block: Block, whileCond: Expression) extends Statement case class Effect(expr: Expression) extends Statement @@ -164,8 +368,20 @@ object Code { val returnVoid: Statement = Return(None) def block(item: Statement, rest: Statement*): Block = - Block(NonEmptyList(item, rest.toList)) + item match { + case block @ Block(_) if rest.isEmpty => block + case _ => Block(NonEmptyList(item, rest.toList)) + } + def ifThenElse(cond: Expression, thenCond: Statement, elseCond: Statement): Statement = { + val first = cond -> block(thenCond) + elseCond match { + case IfElse(ifs, elseCond) => + IfElse(first :: ifs, elseCond) + case notIfElse => + IfElse(NonEmptyList.one(first), Some(block(notIfElse))) + } + } private val equalsDoc = Doc.text(" = ") private val semiDoc = Doc.char(';') private val typeDefDoc = Doc.text("typedef ") @@ -218,6 +434,25 @@ object Code { c match { case Ident(n) => Doc.text(n) case IntLiteral(bi) => Doc.str(bi) + case StrLiteral(str) => + val result = new java.lang.StringBuilder() + val bytes = str.getBytes(StandardCharsets.US_ASCII) + bytes.foreach { c => + val cint = c.toInt & 0xFF + if (25 <= cint && cint <= 126) { + result.append(cint.toChar) + } + else if (cint == 92) { // this is \ + result.append("\\\\") + } + else if (cint == 34) { // this is " + result.append("\\\"") + } + else { + result.append(s"\\x${java.lang.Integer.toHexString(cint)}") + } + } + quoteDoc + (Doc.text(result.toString()) + quoteDoc) case Cast(tpe, expr) => val edoc = expr match { case Ident(n) => Doc.text(n) @@ -316,7 +551,7 @@ object Code { Doc.intercalate(Doc.space, attrs.map(a => Attr.toDoc(a))) + Doc.space } - val paramDoc = Doc.intercalate(Doc.line, args.map(_.toDoc)).nested(4).grouped + val paramDoc = Doc.intercalate(commaLine, args.map(_.toDoc)).nested(4).grouped val prefix = Doc.intercalate(Doc.space, (attrDoc + TypeIdent.toDoc(tpe)) :: @@ -342,6 +577,8 @@ object Code { case Some(expr) => returnSpace + toDoc(expr) + semiDoc } case Block(items) => curlyBlock(items.toList) { s => toDoc(s) } + case Statements(items) => + Doc.intercalate(Doc.line, items.toNonEmptyList.toList.map(toDoc(_))) case IfElse(ifs, els) => //"if (ex) {} else if" val (fcond, fblock) = ifs.head diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index 6c5014b57..50292978e 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -176,7 +176,7 @@ class MatchlessTest extends AnyFunSuite { TestUtils.checkMatchless(""" x = 1 """) { binds => - val map = binds.toMap + val map = binds(TestUtils.testPackage).toMap assert(map.contains(Identifier.Name("x"))) assert(map(Identifier.Name("x")) == Matchless.Literal(Lit(1))) diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index c4593e097..9fb15c714 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -101,7 +101,7 @@ object TestUtils { def checkMatchless[A]( statement: String - )(fn: List[(Identifier.Bindable, Matchless.Expr)] => A): A = { + )(fn: Map[PackageName, List[(Identifier.Bindable, Matchless.Expr)]] => A): A = { val stmts = Parser.unsafeParse(Statement.parser, statement) Package.inferBody(testPackage, Nil, stmts).strictToValidated match { case Validated.Invalid(errs) => @@ -122,7 +122,7 @@ object TestUtils { try { implicit val ec = Par.ecFromService(srv) val comp = MatchlessFromTypedExpr.compile(pm) - fn(comp(testPackage)) + fn(comp) } finally Par.shutdownService(srv) } diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/IdentsTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/IdentsTest.scala index f856abe3f..75edc3ba8 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/IdentsTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/IdentsTest.scala @@ -16,6 +16,16 @@ class IdentsTest extends munit.ScalaCheckSuite { } } + test("allSimpleIdents escape to identity") { + Idents.allSimpleIdents.take(10000).foreach { str => + assertEquals(Idents.escape("", str), str) + } + } + + test("allSimpleIdents are distinct") { + assertEquals(Idents.allSimpleIdents.take(10000).toSet.size, 10000) + } + property("escape starts with prefix") { forAll { (prefix: String, content: String) => assert(Idents.escape(prefix, content).startsWith(prefix)) diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala new file mode 100644 index 000000000..b20432673 --- /dev/null +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -0,0 +1,131 @@ +package org.bykn.bosatsu.codegen.clang + +import cats.data.NonEmptyList +import org.bykn.bosatsu.codegen.Idents +import org.bykn.bosatsu.{PackageName, TestUtils, Identifier, Predef} +import Identifier.Name + +class ClangGenTest extends munit.FunSuite { + val predef_c = Code.Include(true, "bosatsu_predef.h") + + def predef(s: String) = + (PackageName.PredefName -> Name(s)) -> (predef_c, + Code.Ident(Idents.escape("__bsts_predef_", s))) + + def assertPredefFns(fns: String*)(matches: String)(implicit loc: munit.Location) = + TestUtils.checkMatchless(""" +x = 1 +""") { matchlessMap0 => + + val fnSet = fns.toSet + val matchlessMap = matchlessMap0 + .flatMap { + case (k, vs) if k == PackageName.PredefName => + (k -> vs.filter(tup => fnSet(tup._1.asString))) :: Nil + case _ => Nil + } + .toMap + + val res = ClangGen.renderMain( + sortedEnv = Vector( + NonEmptyList.one(PackageName.PredefName -> matchlessMap(PackageName.PredefName)), + ), + externals = + Predef.jvmExternals.toMap.keys.iterator.map { case (_, n) => predef(n) }.toMap, + value = (PackageName.PredefName, Identifier.Name(fns.last)), + evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) + ) + + res match { + case Right(d) => assertEquals(d.render(80), matches) + case Left(e) => fail(e.toString) + } + } + + test("check build_List") { + assertPredefFns("build_List")("""#include "bosatsu_runtime.h" + +BValue __bsts_t_lambda0(BValue __bsts_b_a0, BValue __bsts_b_b0) { + return alloc_enum2(1, __bsts_b_a0, __bsts_b_b0); +} + +BValue ___bsts_g_Bosatsu_l_Predef_l_build__List(BValue __bsts_b_fn0) { + return call_fn2(__bsts_b_fn0, + STATIC_PUREFN(__bsts_t_lambda0), + alloc_enum0(0)); +}""") + } + test("check foldr_List") { + assertPredefFns("foldr_List")("""#include "bosatsu_runtime.h" + +BValue __bsts_t_closure0(BValue* __bstsi_slot, BValue __bsts_b_list1) { + if (get_variant(__bsts_b_list1) == (0)) { + return __bstsi_slot[0]; + } + else { + BValue __bsts_b_h0 = get_enum_index(__bsts_b_list1, 0); + BValue __bsts_b_t0 = get_enum_index(__bsts_b_list1, 1); + return call_fn2(__bstsi_slot[1], + __bsts_b_h0, + __bsts_t_closure0(__bstsi_slot, __bsts_b_t0)); + } +} + +BValue ___bsts_g_Bosatsu_l_Predef_l_foldr__List(BValue __bsts_b_list0, + BValue __bsts_b_fn0, + BValue __bsts_b_acc0) { + BValue __bsts_l_captures1[2] = { __bsts_b_acc0, __bsts_b_fn0 }; + BValue __bsts_b_loop0 = alloc_closure1(2, + __bsts_l_captures1, + __bsts_t_closure0); + return call_fn1(__bsts_b_loop0, __bsts_b_list0); +}""") + } + + test("check foldLeft and reverse_concat") { + assertPredefFns("foldLeft", "reverse_concat")("""#include "bosatsu_runtime.h" + +BValue __bsts_t_closure0(BValue* __bstsi_slot, + BValue __bsts_b_lst1, + BValue __bsts_b_item1) { + _Bool __bsts_l_cond1 = 1; + BValue __bsts_l_res2; + while (__bsts_l_cond1) { + if (get_variant(__bsts_b_lst1) == (0)) { + __bsts_l_cond1 = 0; + __bsts_l_res2 = __bsts_b_item1; + } + else { + BValue __bsts_b_head0 = get_enum_index(__bsts_b_lst1, 0); + BValue __bsts_b_tail0 = get_enum_index(__bsts_b_lst1, 1); + __bsts_b_lst1 = __bsts_b_tail0; + __bsts_b_item1 = call_fn2(__bstsi_slot[0], + __bsts_b_item1, + __bsts_b_head0); + } + } + return __bsts_l_res2; +} + +BValue ___bsts_g_Bosatsu_l_Predef_l_foldLeft(BValue __bsts_b_lst0, + BValue __bsts_b_item0, + BValue __bsts_b_fn0) { + BValue __bsts_l_captures3[1] = { __bsts_b_fn0 }; + BValue __bsts_b_loop0 = alloc_closure2(1, + __bsts_l_captures3, + __bsts_t_closure0); + return call_fn2(__bsts_b_loop0, __bsts_b_lst0, __bsts_b_item0); +} + +BValue __bsts_t_lambda4(BValue __bsts_b_tail0, BValue __bsts_b_h0) { + return alloc_enum2(1, __bsts_b_h0, __bsts_b_tail0); +} + +BValue ___bsts_g_Bosatsu_l_Predef_l_reverse__concat(BValue __bsts_b_front0, + BValue __bsts_b_back0) { + return ___bsts_g_Bosatsu_l_Predef_l_foldLeft(__bsts_b_front0, + __bsts_b_back0, + STATIC_PUREFN(__bsts_t_lambda4)); +}""") + } +} \ No newline at end of file From 6c9f52e80b40a07bdb829d2b55465f0ae31bf34d Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Mon, 18 Nov 2024 07:42:36 -1000 Subject: [PATCH 05/11] Test c generation of all predef (#1259) * Test c generation of all predef * make node tests pass --- .../bosatsu/codegen/clang/ClangGenTest.scala | 67 +++++++++++++ .../codegen/python/PythonGenTest.scala | 32 +----- .../main/scala/org/bykn/bosatsu/FfiCall.scala | 12 +-- .../scala/org/bykn/bosatsu/MainModule.scala | 4 +- .../scala/org/bykn/bosatsu/PackageMap.scala | 10 +- .../bykn/bosatsu/codegen/clang/ClangGen.scala | 97 +++++++++++++++---- .../org/bykn/bosatsu/codegen/clang/Code.scala | 6 +- .../scala/org/bykn/bosatsu/rankn/Type.scala | 9 ++ .../scala/org/bykn/bosatsu/TestUtils.scala | 28 +++++- .../bosatsu/codegen/clang/ClangGenTest.scala | 22 ++++- 10 files changed, 215 insertions(+), 72 deletions(-) create mode 100644 cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala new file mode 100644 index 000000000..11583c22e --- /dev/null +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -0,0 +1,67 @@ +package org.bykn.bosatsu.codegen.clang + +import cats.data.NonEmptyList +import org.bykn.bosatsu.{PackageName, PackageMap, TestUtils, Identifier, Predef} +import Identifier.Name +import org.bykn.bosatsu.MatchlessFromTypedExpr + +import org.bykn.bosatsu.DirectEC.directEC + +class ClangGenTest extends munit.FunSuite { + val predef_c = Code.Include(true, "bosatsu_predef.h") + + def predef(s: String, arity: Int) = + (PackageName.PredefName -> Name(s)) -> (predef_c, + ClangGen.generatedName(PackageName.PredefName, Name(s)), + arity) + + val jvmExternals = { + val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) } + .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] + + { (pn: (PackageName, Identifier)) => ext.get(pn) } + } + + def md5HashToHex(content: String): String = { + val md = java.security.MessageDigest.getInstance("MD5") + val digest = md.digest(content.getBytes("UTF-8")) + digest.map("%02x".format(_)).mkString + } + def testFilesCompilesToHash(path0: String, paths: String*)(hashHex: String)(implicit loc: munit.Location) = { + val pm: PackageMap.Typed[Any] = TestUtils.compileFile(path0, paths*) + /* + val exCode = ClangGen.generateExternalsStub(pm) + println(exCode.render(80)) + sys.error("stop") + */ + val matchlessMap = MatchlessFromTypedExpr.compile(pm) + val topoSort = pm.topoSort.toSuccess.get + val sortedEnv = cats.Functor[Vector].compose[NonEmptyList].map(topoSort) { pn => + (pn, matchlessMap(pn)) + } + + val res = ClangGen.renderMain( + sortedEnv = sortedEnv, + externals = jvmExternals, + value = (PackageName.PredefName, Identifier.Name("ignored")), + evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) + ) + + res match { + case Right(d) => + val everything = d.render(80) + val hashed = md5HashToHex(everything) + assertEquals(hashed, hashHex, s"compilation didn't match. Compiled code:\n\n${"//" * 40}\n\n$everything") + case Left(e) => fail(e.toString) + } + } + + test("test_workspace/Ackermann.bosatsu") { + /* + To inspect the code, change the hash, and it will print the code out + */ + testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")( + "46716ef3c97cf2a79bf17d4033d55854" + ) + } +} \ No newline at end of file diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala index 7c6f197d5..391d98b67 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala @@ -1,17 +1,12 @@ package org.bykn.bosatsu.codegen.python import cats.Show -import cats.data.NonEmptyList import java.io.{ByteArrayInputStream, InputStream} -import java.nio.file.{Paths, Files} import java.util.concurrent.Semaphore import org.bykn.bosatsu.{ - PackageMap, MatchlessFromTypedExpr, - Parser, - Package, - LocationMap, - PackageName + PackageName, + TestUtils } import org.scalacheck.Gen import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ @@ -24,6 +19,8 @@ import org.python.core.{PyInteger, PyFunction, PyObject, PyTuple} import org.bykn.bosatsu.DirectEC.directEC import org.scalatest.funsuite.AnyFunSuite +import TestUtils.compileFile + // Jython seems to have some thread safety issues object JythonBarrier { private val sem = new Semaphore(1) @@ -87,27 +84,6 @@ class PythonGenTest extends AnyFunSuite { } } - def compileFile(path: String, rest: String*): PackageMap.Typed[Any] = { - def toS(s: String): String = - new String(Files.readAllBytes(Paths.get(s)), "UTF-8") - - val packNEL = - NonEmptyList(path, rest.toList) - .map { s => - val str = toS(s) - val pack = Parser.unsafeParse(Package.parser(None), str) - (("", LocationMap(str)), pack) - } - - val res = PackageMap.typeCheckParsed(packNEL, Nil, "") - res.left match { - case Some(err) => sys.error(err.toString) - case None => () - } - - res.right.get - } - def isfromString(s: String): InputStream = new ByteArrayInputStream(s.getBytes("UTF-8")) diff --git a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala index 772642b5c..789a979d8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala +++ b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala @@ -2,12 +2,12 @@ package org.bykn.bosatsu import cats.data.NonEmptyList -sealed abstract class FfiCall { +sealed abstract class FfiCall(val arity: Int) { def call(t: rankn.Type): Value } object FfiCall { - final case class Fn1(fn: Value => Value) extends FfiCall { + final case class Fn1(fn: Value => Value) extends FfiCall(1) { import Value.FnValue private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) => @@ -16,7 +16,7 @@ object FfiCall { def call(t: rankn.Type): Value = evalFn } - final case class Fn2(fn: (Value, Value) => Value) extends FfiCall { + final case class Fn2(fn: (Value, Value) => Value) extends FfiCall(2) { import Value.FnValue private[this] val evalFn: FnValue = @@ -26,7 +26,7 @@ object FfiCall { def call(t: rankn.Type): Value = evalFn } - final case class Fn3(fn: (Value, Value, Value) => Value) extends FfiCall { + final case class Fn3(fn: (Value, Value, Value) => Value) extends FfiCall(3) { import Value.FnValue private[this] val evalFn: FnValue = @@ -37,10 +37,6 @@ object FfiCall { def call(t: rankn.Type): Value = evalFn } - final case class FromFn(callFn: rankn.Type => Value) extends FfiCall { - def call(t: rankn.Type): Value = callFn(t) - } - def getJavaType(t: rankn.Type): List[Class[_]] = { def one(t: rankn.Type): Option[Class[_]] = loop(t, false) match { diff --git a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala index 8f60ad42f..2fadd7efa 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala @@ -630,7 +630,7 @@ abstract class MainModule[IO[_]](implicit val intrinsic = PythonGen.intrinsicValues val missingExternals = allExternals.iterator.flatMap { case (p, names) => - val missing = names.filterNot { case n => + val missing = names.filterNot { case (n, _) => exts((p, n)) || intrinsic.get(p).exists(_(n)) } @@ -703,7 +703,7 @@ abstract class MainModule[IO[_]](implicit Doc.char('[') + Doc.intercalate( Doc.comma + Doc.lineOrSpace, - names.map(b => Doc.text(b.sourceCodeRepr)) + names.map { case (b, _) => Doc.text(b.sourceCodeRepr) } ) + Doc.char(']')).nested(4) } diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala index aca53813e..29ad31053 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala @@ -45,13 +45,17 @@ case class PackageMap[A, B, C, +D]( def allExternals(implicit ev: Package[A, B, C, D] <:< Package.Typed[Any] - ): Map[PackageName, List[Identifier.Bindable]] = + ): Map[PackageName, List[(Identifier.Bindable, rankn.Type)]] = toMap.iterator.map { case (name, pack) => - (name, ev(pack).externalDefs) + val tpack = ev(pack) + (name, tpack.externalDefs.map { n => + (n, tpack.types.getExternalValue(name, n) + .getOrElse(sys.error(s"invariant violation, unknown type: $name $n")) ) + }) }.toMap def topoSort( - ev: Package[A, B, C, D] <:< Package.Typed[Any] + implicit ev: Package[A, B, C, D] <:< Package.Typed[Any] ): Toposort.Result[PackageName] = { val packNames = toMap.keys.iterator.toList.sorted diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 933d323c1..2867f9507 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -5,8 +5,8 @@ import cats.data.{StateT, EitherT, NonEmptyList, Chain} import java.math.BigInteger import java.nio.charset.StandardCharsets import org.bykn.bosatsu.codegen.Idents -import org.bykn.bosatsu.rankn.DataRepr -import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName} +import org.bykn.bosatsu.rankn.{DataRepr, Type} +import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName, PackageMap} import org.bykn.bosatsu.Matchless.Expr import org.bykn.bosatsu.Identifier.Bindable import org.typelevel.paiges.Doc @@ -21,9 +21,43 @@ object ClangGen { case class Unbound(bn: Bindable, inside: Option[(PackageName, Bindable)]) extends Error } + def generateExternalsStub(pm: PackageMap.Typed[Any]): Doc = { + val includes = Code.Include(true, "bosatsu_runtime.h") :: Nil + + def toStmt(pn: PackageName, ident: Identifier.Bindable, arity: Int): Code.Statement = { + val cIdent = generatedName(pn, ident) + val args = Idents.allSimpleIdents.take(arity).map { nm => + Code.Param(Code.TypeIdent.BValue, Code.Ident(nm)) + } + Code.DeclareFn(Nil, Code.TypeIdent.BValue, cIdent, args.toList, Some( + Code.block(Code.Return(Some(Code.IntLiteral.Zero))) + )) + } + + def tpeArity(t: Type): Int = + t match { + case Type.Fun.MaybeQuant(_, args, _) => args.length + case _ => 0 + } + + val fns = pm.allExternals + .iterator + .flatMap { case (p, vs) => + vs.iterator.map { case (n, tpe) => + Code.toDoc(toStmt(p, n, tpeArity(tpe))) + } + } + .toList + + val line2 = Doc.hardLine + Doc.hardLine + + Doc.intercalate(Doc.hardLine, includes.map(Code.toDoc)) + line2 + + Doc.intercalate(line2, fns) + } + def renderMain( sortedEnv: Vector[NonEmptyList[(PackageName, List[(Bindable, Expr)])]], - externals: Map[(PackageName, Bindable), (Code.Include, Code.Ident)], + externals: ((PackageName, Bindable)) => Option[(Code.Include, Code.Ident, Int)], value: (PackageName, Bindable), evaluator: (Code.Include, Code.Ident) ): Either[Error, Doc] = { @@ -44,7 +78,7 @@ object ClangGen { .iterator.flatMap(_.iterator) .flatMap { case (p, vs) => vs.iterator.map { case (b, e) => - (p, b) -> (e, Impl.generatedName(p, b)) + (p, b) -> (e, generatedName(p, b)) } } .toMap @@ -52,15 +86,15 @@ object ClangGen { run(allValues, externals, res) } - private object Impl { - type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)] - type Externals = Map[(PackageName, Bindable), (Code.Include, Code.Ident)] + private def fullName(p: PackageName, b: Bindable): String = + p.asString + "/" + b.asString - def fullName(p: PackageName, b: Bindable): String = - p.asString + "/" + b.asString + def generatedName(p: PackageName, b: Bindable): Code.Ident = + Code.Ident(Idents.escape("___bsts_g_", fullName(p, b))) - def generatedName(p: PackageName, b: Bindable): Code.Ident = - Code.Ident(Idents.escape("___bsts_g_", fullName(p, b))) + private object Impl { + type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)] + type Externals = Function1[(PackageName, Bindable), Option[(Code.Include, Code.Ident, Int)]] trait Env { import Matchless._ @@ -410,11 +444,7 @@ object ClangGen { case Some(nm) => pv(Code.Ident("STATIC_PUREFN")(nm)) case None => - // read_or_build(&__bvalue_foo, make_foo); - for { - value <- staticValueName(pack, name) - consFn <- constructorFn(pack, name) - } yield Code.Ident("read_or_build")(value.addr, consFn): Code.ValueLike + globalIdent(pack, name).map { nm => nm() } } case Local(arg) => directFn(arg) @@ -494,7 +524,7 @@ object ClangGen { case ZeroNat => pv(Code.Ident("BSTS_NAT_0")) case SuccNat => - val arg = Identifier.Name("arg0") + val arg = Identifier.Name("nat") // This relies on optimizing App(SuccNat, _) otherwise // it creates an infinite loop. // Also, this we should cache creation of Lambda/Closure values @@ -567,18 +597,21 @@ object ClangGen { _ <- appendStatement(stmt) } yield () case someValue => + // TODO: if we can create the value statically, we don't + // need the read_or_build trick + // // we materialize an Atomic value to hold the static data // then we generate a function to populate the value for { vl <- innerToValue(someValue) value <- staticValueName(p, b) - consFn <- constructorFn(p, b) _ <- appendStatement(Code.DeclareVar( Code.Attr.Static :: Nil, Code.TypeIdent.AtomicBValue, value, Some(Code.IntLiteral.Zero) )) + consFn <- constructorFn(p, b) _ <- appendStatement(Code.DeclareFn( Code.Attr.Static :: Nil, Code.TypeIdent.BValue, @@ -586,6 +619,15 @@ object ClangGen { Nil, Some(Code.block(Code.returnValue(vl))) )) + readFn <- globalIdent(p, b) + res = Code.Ident("read_or_build")(value.addr, consFn) + _ <- appendStatement(Code.DeclareFn( + Code.Attr.Static :: Nil, + Code.TypeIdent.BValue, + readFn, + Nil, + Some(Code.block(Code.returnValue(res))) + )) } yield () } } @@ -652,8 +694,9 @@ object ClangGen { def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] = StateT { s => val key = (pn, bn) - s.externals.get(key) match { - case Some((incl, ident)) => + s.externals(key) match { + case Some((incl, ident, _)) => + // TODO: suspect that we are ignoring arity here val withIncl = if (s.includeSet(incl)) s else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl) @@ -775,9 +818,21 @@ object ClangGen { // record that this name is a top level function, so applying it can be direct def directFn(pack: PackageName, b: Bindable): T[Option[Code.Ident]] = StateT { s => - s.allValues.get((pack, b)) match { + val key = (pack, b) + s.allValues.get(key) match { case Some((_: Matchless.FnExpr, ident)) => result(s, Some(ident)) + case None => + // this is external + s.externals(key) match { + case Some((incl, ident, arity)) if arity > 0 => + // TODO: suspect that we are ignoring arity here + val withIncl = + if (s.includeSet(incl)) s + else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl) + result(withIncl, Some(ident)) + case _ => result(s, None) + } case _ => result(s, None) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala index 565be64b4..2992894f7 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala @@ -401,8 +401,8 @@ object Code { private val doDoc = Doc.text("do ") private val whileDoc = Doc.text("while") private val arrow = Doc.text("->") - private val questionDoc = Doc.text(" ? ") - private val colonDoc = Doc.text(" : ") + private val questionDoc = Doc.text(" ?") + Doc.line + private val colonDoc = Doc.text(" :") + Doc.line private val quoteDoc = Doc.char('"') private def leftAngleDoc = BinOp.Lt.toDoc private def rightAngleDoc = BinOp.Gt.toDoc @@ -506,7 +506,7 @@ object Code { case noPar @ (Tight(_) | PrefixExpr(_, _) | BinExpr(_, _, _)) => toDoc(noPar) case yesPar => par(toDoc(yesPar)) } - d(cond) + questionDoc + d(t) + colonDoc + d(f) + (d(cond) + (questionDoc + d(t) + colonDoc + d(f)).nested(4)).grouped // Statements case Assignment(t, v) => toDoc(t) + (equalsDoc + (toDoc(v) + semiDoc)) case DeclareArray(tpe, nm, values) => diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 39c21f860..0320ee14f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -923,6 +923,15 @@ object Type { else None } + object MaybeQuant { + def unapply(t: Type): Option[(Option[Quantification], NonEmptyList[Type], Type)] = + t match { + case Quantified(quant, Fun(args, res)) => Some((Some(quant), args, res)) + case Fun(args, res) => Some((None, args, res)) + case _ => None + } + } + def unapply(t: Type): Option[(NonEmptyList[Type], Type)] = { def check( n: Int, diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index 9fb15c714..a608b093f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -1,13 +1,15 @@ package org.bykn.bosatsu -import cats.data.{Ior, Validated} -import cats.implicits._ +import cats.data.{Ior, Validated, NonEmptyList} +import java.nio.file.{Files, Paths} import org.bykn.bosatsu.rankn._ import org.scalatest.{Assertion, Assertions} import Assertions.{succeed, fail} import IorMethods.IorExtension +import cats.syntax.all._ + object TestUtils { def parsedTypeEnvOf( @@ -128,6 +130,28 @@ object TestUtils { } } + def compileFile(path: String, rest: String*)(implicit ec: Par.EC): PackageMap.Typed[Any] = { + def toS(s: String): String = + new String(Files.readAllBytes(Paths.get(s)), "UTF-8") + + val packNEL = + NonEmptyList(path, rest.toList) + .map { s => + val str = toS(s) + val pack = Parser.unsafeParse(Package.parser(None), str) + (("", LocationMap(str)), pack) + } + + val res = PackageMap.typeCheckParsed(packNEL, Nil, "") + res.left match { + case Some(err) => sys.error(err.toString) + case None => () + } + + res.right.get + } + + def makeInputArgs(files: List[(Int, Any)]): List[String] = ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { case (idx, _) => "--input" :: idx.toString :: Nil diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index b20432673..4cd00e10f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -1,16 +1,23 @@ package org.bykn.bosatsu.codegen.clang import cats.data.NonEmptyList -import org.bykn.bosatsu.codegen.Idents import org.bykn.bosatsu.{PackageName, TestUtils, Identifier, Predef} import Identifier.Name class ClangGenTest extends munit.FunSuite { val predef_c = Code.Include(true, "bosatsu_predef.h") - def predef(s: String) = + def predef(s: String, arity: Int) = (PackageName.PredefName -> Name(s)) -> (predef_c, - Code.Ident(Idents.escape("__bsts_predef_", s))) + ClangGen.generatedName(PackageName.PredefName, Name(s)), + arity) + + val jvmExternals = { + val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) } + .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] + + { (pn: (PackageName, Identifier)) => ext.get(pn) } + } def assertPredefFns(fns: String*)(matches: String)(implicit loc: munit.Location) = TestUtils.checkMatchless(""" @@ -30,8 +37,7 @@ x = 1 sortedEnv = Vector( NonEmptyList.one(PackageName.PredefName -> matchlessMap(PackageName.PredefName)), ), - externals = - Predef.jvmExternals.toMap.keys.iterator.map { case (_, n) => predef(n) }.toMap, + externals = jvmExternals, value = (PackageName.PredefName, Identifier.Name(fns.last)), evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) ) @@ -42,6 +48,12 @@ x = 1 } } + def md5HashToHex(content: String): String = { + val md = java.security.MessageDigest.getInstance("MD5") + val digest = md.digest(content.getBytes("UTF-8")) + digest.map("%02x".format(_)).mkString + } + test("check build_List") { assertPredefFns("build_List")("""#include "bosatsu_runtime.h" From 84fb8da8db94b0d4fe8332635b340c8c476b581a Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Tue, 19 Nov 2024 13:50:05 -1000 Subject: [PATCH 06/11] test compilation of generated c (#1261) --- .github/workflows/ci.yml | 17 ++ .gitignore | 2 +- c_runtime/Makefile | 8 +- c_runtime/bosatsu_runtime.h | 2 +- c_runtime/test.c | 3 + .../bosatsu/codegen/clang/ClangGenTest.scala | 21 +- .../scala/org/bykn/bosatsu/MainModule.scala | 152 ++------------ .../scala/org/bykn/bosatsu/Matchless.scala | 2 +- .../org/bykn/bosatsu/codegen/Transpiler.scala | 36 ++++ .../bykn/bosatsu/codegen/clang/ClangGen.scala | 189 +++++++++++++----- .../codegen/clang/ClangTranspiler.scala | 65 ++++++ .../org/bykn/bosatsu/codegen/clang/Code.scala | 10 +- .../codegen/python/PythonTranspiler.scala | 123 ++++++++++++ .../bosatsu/codegen/clang/ClangGenTest.scala | 16 +- 14 files changed, 414 insertions(+), 232 deletions(-) create mode 100644 c_runtime/test.c create mode 100644 core/src/main/scala/org/bykn/bosatsu/codegen/Transpiler.scala create mode 100644 core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangTranspiler.scala create mode 100644 core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonTranspiler.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3df372a42..041363736 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -110,12 +110,29 @@ jobs: - '8' testC: runs-on: ubuntu-latest + strategy: + matrix: + scala: + - '2.13.15' + java: + - '8' steps: - uses: "actions/checkout@v2.1.0" - name: "test runtime code" run: | cd c_runtime make && git diff --quiet + ./test + cd .. + - name: "build assembly" + run: "sbt \"++${{matrix.scala}}; cli/assembly\"" + - name: "generate c code" + run: "./bosatsuj transpile --input_dir test_workspace/ --package_root test_workspace/ --lang c --outdir c_out" + - name: "compile generated c code" + run: | + cp c_runtime/* c_out + cd c_out + gcc -c output.c timeout-minutes: 30 name: ci on: diff --git a/.gitignore b/.gitignore index 1f3cd8628..30af009f1 100644 --- a/.gitignore +++ b/.gitignore @@ -47,4 +47,4 @@ project/metals.sbt project/project/* jsui/bosatsu_ui.js -c_runtime/bosatsu_runtime.o \ No newline at end of file +c_runtime/*.o \ No newline at end of file diff --git a/c_runtime/Makefile b/c_runtime/Makefile index fdbaab144..420b339a1 100644 --- a/c_runtime/Makefile +++ b/c_runtime/Makefile @@ -1,4 +1,4 @@ -all: bosatsu_runtime.o +all: bosatsu_runtime.o test bosatsu_generated.h: typegen.py python3 typegen.py impls > bosatsu_generated.h @@ -7,4 +7,8 @@ bosatsu_decls_generated.h: typegen.py python3 typegen.py headers > bosatsu_decls_generated.h bosatsu_runtime.o: bosatsu_runtime.h bosatsu_runtime.c bosatsu_decls_generated.h bosatsu_generated.h - gcc -c bosatsu_runtime.c \ No newline at end of file + gcc -c bosatsu_runtime.c + +# this will eventually have test code for the runtime and predef +test: test.c + gcc -O3 -o test test.c diff --git a/c_runtime/bosatsu_runtime.h b/c_runtime/bosatsu_runtime.h index 59bde9fdd..702ede9be 100644 --- a/c_runtime/bosatsu_runtime.h +++ b/c_runtime/bosatsu_runtime.h @@ -95,7 +95,7 @@ BValue bsts_string_from_utf8_bytes_copy(size_t len, char* bytes); _Bool bsts_equals_string(BValue left, BValue right); BValue bsts_integer_from_int(int small_int); -BValue bsts_integer_from_words_copy(_Bool is_pos, size_t size, int32_t* words); +BValue bsts_integer_from_words_copy(_Bool is_pos, size_t size, uint32_t* words); _Bool bsts_equals_int(BValue left, BValue right); BValue alloc_external(void* eval, FreeFn free_fn); diff --git a/c_runtime/test.c b/c_runtime/test.c new file mode 100644 index 000000000..9825f5741 --- /dev/null +++ b/c_runtime/test.c @@ -0,0 +1,3 @@ +int main(int argc, char** argv) { + return 0; +} \ No newline at end of file diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 11583c22e..066a3c77d 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -1,27 +1,12 @@ package org.bykn.bosatsu.codegen.clang import cats.data.NonEmptyList -import org.bykn.bosatsu.{PackageName, PackageMap, TestUtils, Identifier, Predef} -import Identifier.Name +import org.bykn.bosatsu.{PackageName, PackageMap, TestUtils, Identifier} import org.bykn.bosatsu.MatchlessFromTypedExpr import org.bykn.bosatsu.DirectEC.directEC class ClangGenTest extends munit.FunSuite { - val predef_c = Code.Include(true, "bosatsu_predef.h") - - def predef(s: String, arity: Int) = - (PackageName.PredefName -> Name(s)) -> (predef_c, - ClangGen.generatedName(PackageName.PredefName, Name(s)), - arity) - - val jvmExternals = { - val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) } - .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] - - { (pn: (PackageName, Identifier)) => ext.get(pn) } - } - def md5HashToHex(content: String): String = { val md = java.security.MessageDigest.getInstance("MD5") val digest = md.digest(content.getBytes("UTF-8")) @@ -42,7 +27,7 @@ class ClangGenTest extends munit.FunSuite { val res = ClangGen.renderMain( sortedEnv = sortedEnv, - externals = jvmExternals, + externals = ClangGen.ExternalResolver.FromJvmExternals, value = (PackageName.PredefName, Identifier.Name("ignored")), evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) ) @@ -61,7 +46,7 @@ class ClangGenTest extends munit.FunSuite { To inspect the code, change the hash, and it will print the code out */ testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")( - "46716ef3c97cf2a79bf17d4033d55854" + "07c2ef3320a86aafc048aa9f7e3adc4d" ) } } \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala index 2fadd7efa..a2dd454fe 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala @@ -7,7 +7,8 @@ import cats.parse.{Parser0 => P0, Parser => P} import org.typelevel.paiges.Doc import scala.util.{Failure, Success, Try} -import CollectionUtils.listToUnique +import codegen.Transpiler + import Identifier.Bindable import IorMethods.IorExtension import LocationMap.Colorize @@ -599,144 +600,16 @@ abstract class MainModule[IO[_]](implicit } } - sealed abstract class Transpiler(val name: String) { - def renderAll( - pm: PackageMap.Typed[Any], - externals: List[String], - evaluators: List[String] - )(implicit ec: Par.EC): IO[List[(NonEmptyList[String], Doc)]] - } - object Transpiler { - case object PythonTranspiler extends Transpiler("python") { - def renderAll( - pm: PackageMap.Typed[Any], - externals: List[String], - evaluators: List[String] - )(implicit ec: Par.EC): IO[List[(NonEmptyList[String], Doc)]] = { - import codegen.python.PythonGen - - val allExternals = pm.allExternals - val cmp = MatchlessFromTypedExpr.compile(pm) - moduleIOMonad.catchNonFatal { - val parsedExt = - externals.map(Parser.unsafeParse(PythonGen.externalParser, _)) - val extMap = listToUnique(parsedExt.flatten)( - { case (p, b, _, _) => (p, b) }, - { case (_, _, m, f) => (m, f) }, - "expected each package/name to map to just one file" - ).get - - val exts = extMap.keySet - val intrinsic = PythonGen.intrinsicValues - val missingExternals = - allExternals.iterator.flatMap { case (p, names) => - val missing = names.filterNot { case (n, _) => - exts((p, n)) || intrinsic.get(p).exists(_(n)) - } - - if (missing.isEmpty) Nil - else (p, missing.sorted) :: Nil - }.toList - - if (missingExternals.isEmpty) { - val tests = pm.toMap.iterator.flatMap { case (n, pack) => - Package.testValue(pack).iterator.map { case (bn, _, _) => - (n, bn) - } - }.toMap - - val parsedEvals = - evaluators.map(Parser.unsafeParse(PythonGen.evaluatorParser, _)) - // TODO, we don't check that these types even exist in the fully - // universe, we should. If you have a typo in a type or package name - // you just get silently ignored - val typeEvalMap = listToUnique(parsedEvals.flatten)( - t => t._1, - t => t._2, - "expected each type to have to just one evaluator" - ).get - - val evalMap = pm.toMap.iterator.flatMap { case (n, p) => - val optEval = p.lets.findLast { case (_, _, te) => - // TODO this should really e checking that te.getType <:< a key - // in the map. - typeEvalMap.contains(te.getType) - } - optEval.map { case (b, _, te) => - val (m, i) = typeEvalMap(te.getType) - (n, (b, m, i)) - } - }.toMap + val transOpt: Opts[Transpiler] = { + val allTrans: List[Transpiler] = + codegen.python.PythonTranspiler :: + codegen.clang.ClangTranspiler :: + Nil - val docs = PythonGen - .renderAll(cmp, extMap, tests, evalMap) - .iterator - .map { case (_, (path, doc)) => - (path.map(_.name), doc) - } - .toList - - // python also needs empty __init__.py files in every parent directory - def prefixes[A]( - paths: List[(NonEmptyList[String], A)] - ): List[(NonEmptyList[String], Doc)] = { - val inits = - paths.map { case (path, _) => - val parent = path.init - val initPy = parent :+ "__init__.py" - NonEmptyList.fromListUnsafe(initPy) - }.toSet - - inits.toList.sorted.map(p => (p, Doc.empty)) - } - - prefixes(docs) ::: docs - } else { - // we need to render this nicer - val missingDoc = - missingExternals - .sortBy(_._1) - .map { case (p, names) => - (Doc.text("package") + Doc.lineOrSpace + Doc.text( - p.asString - ) + Doc.lineOrSpace + - Doc.char('[') + - Doc.intercalate( - Doc.comma + Doc.lineOrSpace, - names.map { case (b, _) => Doc.text(b.sourceCodeRepr) } - ) + Doc.char(']')).nested(4) - } - - val message = Doc.text( - "Missing external values:" - ) + (Doc.line + Doc.intercalate(Doc.line, missingDoc)).nested(4) - - throw new IllegalArgumentException(message.renderTrim(80)) - } - } - } - } - - val all: List[Transpiler] = List(PythonTranspiler) - - implicit def argumentForTranspiler: Argument[Transpiler] = - new Argument[Transpiler] { - val nameTo = all.iterator.map(t => (t.name, t)).toMap - - def defaultMetavar: String = "transpiler" - def read(string: String): ValidatedNel[String, Transpiler] = - nameTo.get(string) match { - case Some(t) => Validated.valid(t) - case None => - val keys = nameTo.keys.toList.sorted.mkString(",") - Validated.invalidNel( - s"unknown transpiler: $string, expected one of: $keys" - ) - } - } + implicit val arg = Transpiler.argumentFromTranspilers(allTrans) - val opt: Opts[Transpiler] = - Opts.option[Transpiler]("lang", "language to transpile to") + Opts.option[Transpiler]("lang", + s"language to transpile to (${allTrans.map(_.name).sorted.mkString(", ")})") } sealed abstract class JsonInput { @@ -1007,7 +880,8 @@ abstract class MainModule[IO[_]](implicit (packs, names) = pn extStrs <- exts.traverse(readPath) evalStrs <- evals.traverse(readPath) - data <- generator.renderAll(packs, extStrs, evalStrs) + dataTry = generator.renderAll(packs, extStrs, evalStrs) + data <- moduleIOMonad.fromTry(dataTry) } yield Output.TranspileOut(data, outDir) } } @@ -1503,7 +1377,7 @@ abstract class MainModule[IO[_]](implicit val transpileOpt = ( Inputs.runtimeOpts, colorOpt, - Transpiler.opt, + transOpt, Opts.option[Path]( "outdir", help = "directory to write all output into" diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index dac23c593..caee42805 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -334,7 +334,7 @@ object Matchless { .map { case (b, idx) => (b, ClosureSlot(idx)) } .toMap val captures = frees.flatMap { f => - if (f != n) (apply(f) :: Nil) + if (cats.Eq[Identifier].neqv(f, n)) (apply(f) :: Nil) else Nil } (copy(slots = newSlots), captures) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/Transpiler.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/Transpiler.scala new file mode 100644 index 000000000..b89c7116b --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/Transpiler.scala @@ -0,0 +1,36 @@ +package org.bykn.bosatsu.codegen + +import cats.data.{NonEmptyList, ValidatedNel, Validated} +import com.monovore.decline.Argument +import org.bykn.bosatsu.{PackageMap, Par} +import org.typelevel.paiges.Doc +import scala.util.Try + +trait Transpiler { + def name: String + + def renderAll( + pm: PackageMap.Typed[Any], + externals: List[String], + evaluators: List[String] + )(implicit ec: Par.EC): Try[List[(NonEmptyList[String], Doc)]] +} + +object Transpiler { + def argumentFromTranspilers(all: List[Transpiler]): Argument[Transpiler] = + new Argument[Transpiler] { + val nameTo = all.iterator.map(t => (t.name, t)).toMap + lazy val keys = nameTo.keys.toList.sorted.mkString(",") + + def defaultMetavar: String = "transpiler" + def read(string: String): ValidatedNel[String, Transpiler] = + nameTo.get(string) match { + case Some(t) => Validated.valid(t) + case None => + Validated.invalidNel( + s"unknown transpiler: $string, expected one of: $keys" + ) + } + } + +} \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 2867f9507..8251e8989 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -6,58 +6,142 @@ import java.math.BigInteger import java.nio.charset.StandardCharsets import org.bykn.bosatsu.codegen.Idents import org.bykn.bosatsu.rankn.{DataRepr, Type} -import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName, PackageMap} +import org.bykn.bosatsu.{Identifier, Lit, Matchless, Predef, PackageName, PackageMap} import org.bykn.bosatsu.Matchless.Expr import org.bykn.bosatsu.Identifier.Bindable import org.typelevel.paiges.Doc +import scala.collection.immutable.{SortedMap, SortedSet} import cats.syntax.all._ object ClangGen { - sealed abstract class Error + sealed abstract class Error { + // TODO: implement this in a nice way + def display: Doc = Doc.text(this.toString) + } + object Error { case class UnknownValue(pack: PackageName, value: Bindable) extends Error case class InvariantViolation(message: String, expr: Expr) extends Error case class Unbound(bn: Bindable, inside: Option[(PackageName, Bindable)]) extends Error } - def generateExternalsStub(pm: PackageMap.Typed[Any]): Doc = { - val includes = Code.Include(true, "bosatsu_runtime.h") :: Nil + trait ExternalResolver { + def names: SortedMap[PackageName, SortedSet[Bindable]] + def apply(p: PackageName, b: Bindable): Option[(Code.Include, Code.Ident, Int)] + + final def generateExternalsStub: SortedMap[String, Doc] = { + val includes = Code.Include(true, "bosatsu_runtime.h") :: Nil - def toStmt(pn: PackageName, ident: Identifier.Bindable, arity: Int): Code.Statement = { - val cIdent = generatedName(pn, ident) - val args = Idents.allSimpleIdents.take(arity).map { nm => - Code.Param(Code.TypeIdent.BValue, Code.Ident(nm)) + def toStmt(cIdent: Code.Ident, arity: Int): Code.Statement = { + val args = Idents.allSimpleIdents.take(arity).map { nm => + Code.Param(Code.TypeIdent.BValue, Code.Ident(nm)) + } + Code.DeclareFn(Nil, Code.TypeIdent.BValue, cIdent, args.toList, None) } - Code.DeclareFn(Nil, Code.TypeIdent.BValue, cIdent, args.toList, Some( - Code.block(Code.Return(Some(Code.IntLiteral.Zero))) - )) + + val line2 = Doc.hardLine + Doc.hardLine + val includeRuntime = Doc.intercalate(Doc.hardLine, includes.map(Code.toDoc)) + + SortedMap.empty[String, Doc] ++ names + .iterator + .flatMap { case (p, binds) => + + val fns = binds.iterator.flatMap { n => + apply(p, n) + } + .map { case (i, n, arity) => + i.filename -> Code.toDoc(toStmt(n, arity)) + } + .toList + .groupByNel(_._1) + + + fns.iterator.map { case (incl, nelInner) => + incl -> ( + includeRuntime + + line2 + + Doc.intercalate(line2, nelInner.toList.map(_._2))) + } + } } - def tpeArity(t: Type): Int = - t match { - case Type.Fun.MaybeQuant(_, args, _) => args.length - case _ => 0 - } + } + + object ExternalResolver { + def stdExtFileName(pn: PackageName): String = + s"${Idents.escape("bosatsu_ext_", pn.asString)}.h" + + def stdExternals(pm: PackageMap.Typed[Any]): ExternalResolver = { + + def tpeArity(t: Type): Int = + t match { + case Type.Fun.MaybeQuant(_, args, _) => args.length + case _ => 0 + } + + val allExt = pm.allExternals + val extMap = allExt + .iterator + .map { case (p, vs) => + val fileName = ExternalResolver.stdExtFileName(p) + + val fns = vs.iterator.map { case (n, tpe) => + val cIdent = generatedName(p, n) + n -> (cIdent, tpeArity(tpe)) + } + .toMap - val fns = pm.allExternals - .iterator - .flatMap { case (p, vs) => - vs.iterator.map { case (n, tpe) => - Code.toDoc(toStmt(p, n, tpeArity(tpe))) + p -> (Code.Include(true, fileName), fns) } + .toMap + + new ExternalResolver { + lazy val names: SortedMap[PackageName, SortedSet[Bindable]] = + allExt.iterator.map { case (p, vs) => + p -> vs.iterator.map { case (b, _) => b }.to(SortedSet) + } + .to(SortedMap) + + def apply(p: PackageName, b: Bindable): Option[(Code.Include, Code.Ident, Int)] = + for { + (include, inner) <- extMap.get(p) + (ident, arity) <- inner.get(b) + } yield (include, ident, arity) } - .toList + } + + val FromJvmExternals: ExternalResolver = + new ExternalResolver { + val predef_c = Code.Include(true, stdExtFileName(PackageName.PredefName)) + + def predef(s: String, arity: Int) = + (PackageName.PredefName -> Identifier.Name(s)) -> (predef_c, + ClangGen.generatedName(PackageName.PredefName, Identifier.Name(s)), arity) + + val ext = Predef + .jvmExternals.toMap.iterator.map { case ((_, n), ffi) => + predef(n, ffi.arity) + } + .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] - val line2 = Doc.hardLine + Doc.hardLine + lazy val names: SortedMap[PackageName, SortedSet[Bindable]] = { + val sm = Predef + .jvmExternals.toMap.iterator.map { case (pn, _) => pn } + .toList + .groupByNel(_._1) - Doc.intercalate(Doc.hardLine, includes.map(Code.toDoc)) + line2 + - Doc.intercalate(line2, fns) + sm.map { case (k, vs) => + (k, vs.toList.iterator.map { case (_, n) => Identifier.Name(n) }.to(SortedSet)) + } + } + def apply(p: PackageName, b: Bindable) = ext.get((p, b)) + } } def renderMain( sortedEnv: Vector[NonEmptyList[(PackageName, List[(Bindable, Expr)])]], - externals: ((PackageName, Bindable)) => Option[(Code.Include, Code.Ident, Int)], + externals: ExternalResolver, value: (PackageName, Bindable), evaluator: (Code.Include, Code.Ident) ): Either[Error, Doc] = { @@ -94,14 +178,13 @@ object ClangGen { private object Impl { type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)] - type Externals = Function1[(PackageName, Bindable), Option[(Code.Include, Code.Ident, Int)]] trait Env { import Matchless._ type T[A] implicit val monadImpl: Monad[T] - def run(pm: AllValues, externals: Externals, t: T[Unit]): Either[Error, Doc] + def run(pm: AllValues, externals: ExternalResolver, t: T[Unit]): Either[Error, Doc] def appendStatement(stmt: Code.Statement): T[Unit] def error[A](e: => Error): T[A] def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] @@ -420,23 +503,28 @@ object ClangGen { expr match { case fn: FnExpr => innerFn(fn) case Let(Right(arg), argV, in) => - bind(arg) { - for { - name <- getBinding(arg) - v <- innerToValue(argV) - result <- innerToValue(in) - stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) - } yield stmt +: result + // arg isn't in scope for argV + innerToValue(argV).flatMap { v => + bind(arg) { + for { + name <- getBinding(arg) + result <- innerToValue(in) + stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) + } yield stmt +: result + } } case Let(Left(LocalAnon(idx)), argV, in) => - bindAnon(idx) { - for { - name <- getAnon(idx) - v <- innerToValue(argV) - result <- innerToValue(in) - stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) - } yield stmt +: result - } + // LocalAnon(idx) isn't in scope for argV + innerToValue(argV) + .flatMap { v => + bindAnon(idx) { + for { + name <- getAnon(idx) + result <- innerToValue(in) + stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) + } yield stmt +: result + } + } case app @ App(_, _) => innerApp(app) case Global(pack, name) => directFn(pack, name) @@ -515,7 +603,7 @@ object ClangGen { } case MakeStruct(arity) => pv { - if (arity == 0) Code.Ident("PURE_VALUE_TAG") + if (arity == 0) Code.Ident("PURE_VALUE_TAG").castTo(Code.TypeIdent.BValue) else { val allocStructFn = s"alloc_struct$arity" Code.Ident("STATIC_PUREFN")(Code.Ident(allocStructFn)) @@ -642,7 +730,7 @@ object ClangGen { new Env { case class State( allValues: AllValues, - externals: Externals, + externals: ExternalResolver, includeSet: Set[Code.Include], includes: Chain[Code.Include], stmts: Chain[Code.Statement], @@ -657,7 +745,7 @@ object ClangGen { } object State { - def init(allValues: AllValues, externals: Externals): State = { + def init(allValues: AllValues, externals: ExternalResolver): State = { val defaultIncludes = List(Code.Include(true, "bosatsu_runtime.h")) @@ -671,7 +759,7 @@ object ClangGen { implicit val monadImpl: Monad[T] = catsMonad[State] - def run(pm: AllValues, externals: Externals, t: T[Unit]): Either[Error, Doc] = + def run(pm: AllValues, externals: ExternalResolver, t: T[Unit]): Either[Error, Doc] = t.run(State.init(pm, externals)) .value // get the value out of the EitherT .value // evaluate the Eval @@ -693,8 +781,7 @@ object ClangGen { def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] = StateT { s => - val key = (pn, bn) - s.externals(key) match { + s.externals(pn, bn) match { case Some((incl, ident, _)) => // TODO: suspect that we are ignoring arity here val withIncl = @@ -703,6 +790,7 @@ object ClangGen { result(withIncl, ident) case None => + val key = (pn, bn) s.allValues.get(key) match { case Some((_, ident)) => result(s, ident) case None => errorRes(Error.UnknownValue(pn, bn)) @@ -824,9 +912,8 @@ object ClangGen { result(s, Some(ident)) case None => // this is external - s.externals(key) match { + s.externals(pack, b) match { case Some((incl, ident, arity)) if arity > 0 => - // TODO: suspect that we are ignoring arity here val withIncl = if (s.includeSet(incl)) s else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangTranspiler.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangTranspiler.scala new file mode 100644 index 000000000..07a1dedc2 --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangTranspiler.scala @@ -0,0 +1,65 @@ +package org.bykn.bosatsu.codegen.clang + +import cats.data.NonEmptyList +import org.bykn.bosatsu.codegen.Transpiler +import org.bykn.bosatsu.{Identifier, MatchlessFromTypedExpr, PackageName, PackageMap, Par} +import org.typelevel.paiges.Doc +import scala.util.{Failure, Success, Try} + +case object ClangTranspiler extends Transpiler { + case class GenError(error: ClangGen.Error) extends Exception(s"clang gen error: ${error.display.render(80)}") + + case class CircularPackagesFound(loop: NonEmptyList[PackageName]) + extends Exception(s"circular dependencies found in packages: ${ + loop.map(_.asString).toList.mkString(", ") + }") + + def name = "c" + + def externalsFor(pm: PackageMap.Typed[Any], arg: List[String]): ClangGen.ExternalResolver = + // TODO: we shouldn't take the arg at the higher level, but customizing args + // for backends hasn't be implemented yet + ClangGen.ExternalResolver.stdExternals(pm) + + def renderAll( + pm: PackageMap.Typed[Any], + externals: List[String], + evaluators: List[String] + )(implicit ec: Par.EC): Try[List[(NonEmptyList[String], Doc)]] = { + // we have to render the code in sorted order + val sorted = pm.topoSort + NonEmptyList.fromList(sorted.loopNodes) match { + case Some(loop) => Failure(CircularPackagesFound(loop)) + case None => + val matchlessMap = MatchlessFromTypedExpr.compile(pm) + + val ext = externalsFor(pm, externals) + val doc = ClangGen.renderMain( + sortedEnv = cats.Functor[Vector] + .compose[NonEmptyList] + .map(sorted.layers) { pn => pn -> matchlessMap(pn) }, + externals = ext, + // TODO: this is currently ignored + value = (PackageName.PredefName, Identifier.Name("todo")), + // TODO: this is also ignored currently + evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) + ) + + doc match { + case Left(err) => Failure(GenError(err)) + case Right(doc) => + // TODO: this name needs to be an option + val outputName = NonEmptyList("output.c", Nil) + + val externalHeaders = ext.generateExternalsStub + .iterator.map { case (n, d) => + NonEmptyList.one(n) -> d + } + .toList + + // TODO: always outputing the headers may not be right, maybe an option + Success((outputName -> doc) :: externalHeaders) + } + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala index 2992894f7..cb0d1d3f0 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala @@ -133,6 +133,8 @@ object Code { def stmt: Statement = Effect(this) + def castTo(tpe: TypeIdent): Expression = Cast(tpe, this) + def bin(op: BinOp, rhs: Expression): Expression = BinExpr(this, op, rhs) @@ -439,15 +441,15 @@ object Code { val bytes = str.getBytes(StandardCharsets.US_ASCII) bytes.foreach { c => val cint = c.toInt & 0xFF - if (25 <= cint && cint <= 126) { - result.append(cint.toChar) - } - else if (cint == 92) { // this is \ + if (cint == 92) { // this is \ result.append("\\\\") } else if (cint == 34) { // this is " result.append("\\\"") } + else if (25 <= cint && cint <= 126) { + result.append(cint.toChar) + } else { result.append(s"\\x${java.lang.Integer.toHexString(cint)}") } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonTranspiler.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonTranspiler.scala new file mode 100644 index 000000000..bbad6055d --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonTranspiler.scala @@ -0,0 +1,123 @@ +package org.bykn.bosatsu.codegen.python + +import cats.data.NonEmptyList +import cats.implicits.catsKernelOrderingForOrder +import org.bykn.bosatsu.{Package, PackageMap, Par, Parser, MatchlessFromTypedExpr} +import org.bykn.bosatsu.codegen.Transpiler +import org.typelevel.paiges.Doc +import scala.util.Try + +import org.bykn.bosatsu.CollectionUtils.listToUnique + +case object PythonTranspiler extends Transpiler { + val name: String = "python" + + def renderAll( + pm: PackageMap.Typed[Any], + externals: List[String], + evaluators: List[String] + )(implicit ec: Par.EC): Try[List[(NonEmptyList[String], Doc)]] = { + + val cmp = MatchlessFromTypedExpr.compile(pm) + Try { + val parsedExt = + externals.map(Parser.unsafeParse(PythonGen.externalParser, _)) + + val extMap = listToUnique(parsedExt.flatten)( + { case (p, b, _, _) => (p, b) }, + { case (_, _, m, f) => (m, f) }, + "expected each package/name to map to just one file" + ).get + + val exts = extMap.keySet + val intrinsic = PythonGen.intrinsicValues + + val allExternals = pm.allExternals + val missingExternals = + allExternals.iterator.flatMap { case (p, names) => + val missing = names.filterNot { case (n, _) => + exts((p, n)) || intrinsic.get(p).exists(_(n)) + } + + if (missing.isEmpty) Nil + else (p, missing.sorted) :: Nil + }.toList + + if (missingExternals.isEmpty) { + val tests = pm.toMap.iterator.flatMap { case (n, pack) => + Package.testValue(pack).iterator.map { case (bn, _, _) => + (n, bn) + } + }.toMap + + val parsedEvals = + evaluators.map(Parser.unsafeParse(PythonGen.evaluatorParser, _)) + // TODO, we don't check that these types even exist in the fully + // universe, we should. If you have a typo in a type or package name + // you just get silently ignored + val typeEvalMap = listToUnique(parsedEvals.flatten)( + t => t._1, + t => t._2, + "expected each type to have to just one evaluator" + ).get + + val evalMap = pm.toMap.iterator.flatMap { case (n, p) => + val optEval = p.lets.findLast { case (_, _, te) => + // TODO this should really e checking that te.getType <:< a key + // in the map. + typeEvalMap.contains(te.getType) + } + optEval.map { case (b, _, te) => + val (m, i) = typeEvalMap(te.getType) + (n, (b, m, i)) + } + }.toMap + + val docs = PythonGen + .renderAll(cmp, extMap, tests, evalMap) + .iterator + .map { case (_, (path, doc)) => + (path.map(_.name), doc) + } + .toList + + // python also needs empty __init__.py files in every parent directory + def prefixes[A]( + paths: List[(NonEmptyList[String], A)] + ): List[(NonEmptyList[String], Doc)] = { + val inits = + paths.map { case (path, _) => + val parent = path.init + val initPy = parent :+ "__init__.py" + NonEmptyList.fromListUnsafe(initPy) + }.toSet + + inits.toList.sorted.map(p => (p, Doc.empty)) + } + + prefixes(docs) ::: docs + } else { + // we need to render this nicer + val missingDoc = + missingExternals + .sortBy(_._1) + .map { case (p, names) => + (Doc.text("package") + Doc.lineOrSpace + Doc.text( + p.asString + ) + Doc.lineOrSpace + + Doc.char('[') + + Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + names.map { case (b, _) => Doc.text(b.sourceCodeRepr) } + ) + Doc.char(']')).nested(4) + } + + val message = Doc.text( + "Missing external values:" + ) + (Doc.line + Doc.intercalate(Doc.line, missingDoc)).nested(4) + + throw new IllegalArgumentException(message.renderTrim(80)) + } + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 4cd00e10f..8adebad0e 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -5,20 +5,6 @@ import org.bykn.bosatsu.{PackageName, TestUtils, Identifier, Predef} import Identifier.Name class ClangGenTest extends munit.FunSuite { - val predef_c = Code.Include(true, "bosatsu_predef.h") - - def predef(s: String, arity: Int) = - (PackageName.PredefName -> Name(s)) -> (predef_c, - ClangGen.generatedName(PackageName.PredefName, Name(s)), - arity) - - val jvmExternals = { - val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) } - .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] - - { (pn: (PackageName, Identifier)) => ext.get(pn) } - } - def assertPredefFns(fns: String*)(matches: String)(implicit loc: munit.Location) = TestUtils.checkMatchless(""" x = 1 @@ -37,7 +23,7 @@ x = 1 sortedEnv = Vector( NonEmptyList.one(PackageName.PredefName -> matchlessMap(PackageName.PredefName)), ), - externals = jvmExternals, + externals = ClangGen.ExternalResolver.FromJvmExternals, value = (PackageName.PredefName, Identifier.Name(fns.last)), evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) ) From 35cbe08b7d94455d101449eeeb93e8948241a8a8 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Tue, 19 Nov 2024 14:10:59 -1000 Subject: [PATCH 07/11] evaluate static string matches in TypedExprNormalization (#1262) * evaluate static string matches in matchless * refactor to try to move to TypedExprNormalization... * fix tests * close TODO by normalizing Lit Str and Chr patterns * simplify conversion to regex * clean up toRegex more --- .../src/main/scala/org/bykn/bosatsu/Lit.scala | 9 +- .../scala/org/bykn/bosatsu/Matchless.scala | 73 +----- .../org/bykn/bosatsu/MatchlessToValue.scala | 125 +--------- .../main/scala/org/bykn/bosatsu/Pattern.scala | 56 +++-- .../scala/org/bykn/bosatsu/StringUtil.scala | 19 ++ .../org/bykn/bosatsu/TotalityCheck.scala | 4 +- .../bykn/bosatsu/TypedExprNormalization.scala | 41 ++-- .../bosatsu/codegen/python/PythonGen.scala | 1 + .../org/bykn/bosatsu/pattern/Matcher.scala | 4 +- .../org/bykn/bosatsu/pattern/SeqPattern.scala | 2 +- .../org/bykn/bosatsu/pattern/Splitter.scala | 76 +++--- .../org/bykn/bosatsu/pattern/StrPart.scala | 221 ++++++++++++++++++ .../src/test/scala/org/bykn/bosatsu/Gen.scala | 32 ++- .../scala/org/bykn/bosatsu/ParserTest.scala | 24 +- .../bykn/bosatsu/pattern/SeqPatternTest.scala | 53 +++-- .../bykn/bosatsu/pattern/StrPartTest.scala | 103 ++++++++ .../pattern/StringSeqPatternSetLaws.scala | 57 ++--- 17 files changed, 570 insertions(+), 330 deletions(-) create mode 100644 core/src/main/scala/org/bykn/bosatsu/pattern/StrPart.scala create mode 100644 core/src/test/scala/org/bykn/bosatsu/pattern/StrPartTest.scala diff --git a/core/src/main/scala/org/bykn/bosatsu/Lit.scala b/core/src/main/scala/org/bykn/bosatsu/Lit.scala index b76039678..78882645a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Lit.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Lit.scala @@ -45,10 +45,15 @@ object Lit { } else new Integer(BigInteger.valueOf(l)) } - case class Str(toStr: String) extends Lit { + // Means this lit could be the result of a string match + sealed abstract class StringMatchResult extends Lit { + def asStr: String + } + case class Str(toStr: String) extends StringMatchResult { def unboxToAny: Any = toStr + def asStr = toStr } - case class Chr(asStr: String) extends Lit { + case class Chr(asStr: String) extends StringMatchResult { def toCodePoint: Int = asStr.codePointAt(0) def unboxToAny: Any = asStr } diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index caee42805..34a4f8a03 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -2,6 +2,7 @@ package org.bykn.bosatsu import cats.{Monad, Monoid} import cats.data.{Chain, NonEmptyList, WriterT} +import org.bykn.bosatsu.pattern.StrPart import org.bykn.bosatsu.rankn.{DataRepr, Type, RefSpace} import Identifier.{Bindable, Constructor} @@ -24,66 +25,6 @@ object Matchless { def body: Expr } - sealed abstract class StrPart - object StrPart { - sealed abstract class Glob(val capture: Boolean) extends StrPart - sealed abstract class CharPart(val capture: Boolean) extends StrPart - case object WildStr extends Glob(false) - case object IndexStr extends Glob(true) - case object WildChar extends CharPart(false) - case object IndexChar extends CharPart(true) - case class LitStr(asString: String) extends StrPart - - sealed abstract class MatchSize(val isExact: Boolean) { - def charCount: Int - def canMatch(cp: Int): Boolean - // we know chars/2 <= cpCount <= chars for utf16 - def canMatchUtf16Count(chars: Int): Boolean - } - object MatchSize { - case class Exactly(charCount: Int) extends MatchSize(true) { - def canMatch(cp: Int): Boolean = cp == charCount - def canMatchUtf16Count(chars: Int): Boolean = { - val cpmin = chars / 2 - val cpmax = chars - (cpmin <= charCount) && (charCount <= cpmax) - } - } - case class AtLeast(charCount: Int) extends MatchSize(false) { - def canMatch(cp: Int): Boolean = charCount <= cp - def canMatchUtf16Count(chars: Int): Boolean = { - val cpmax = chars - // we have any cp in [cpmin, cpmax] - // but we require charCount <= cp - (charCount <= cpmax) - } - } - - private val atLeast0 = AtLeast(0) - private val exactly0 = Exactly(0) - private val exactly1 = Exactly(1) - - def from(sp: StrPart): MatchSize = - sp match { - case _: Glob => atLeast0 - case _: CharPart => exactly1 - case LitStr(str) => - Exactly(str.codePointCount(0, str.length)) - } - - def apply[F[_]: cats.Foldable](f: F[StrPart]): MatchSize = - cats.Foldable[F].foldMap(f)(from) - - implicit val monoidMatchSize: Monoid[MatchSize] = - new Monoid[MatchSize] { - def empty: MatchSize = exactly0 - def combine(l: MatchSize, r: MatchSize) = - if (l.isExact && r.isExact) Exactly(l.charCount + r.charCount) - else AtLeast(l.charCount + r.charCount) - } - } - } - // name is set for recursive (but not tail recursive) methods case class Lambda( captures: List[Expr], @@ -527,10 +468,6 @@ object Matchless { case Pattern.StrPart.NamedChar(n) => n } - val muts = sbinds.traverse { b => - makeAnon.map(LocalAnonMut(_)).map((b, _)) - } - val pat = items.toList.map { case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar @@ -539,10 +476,13 @@ object Matchless { case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) } - muts.map { binds => + sbinds.traverse { b => + makeAnon.map(LocalAnonMut(_)).map((b, _)) + } + .map { binds => val ms = binds.map(_._2) - NonEmptyList.of((ms, MatchString(arg, pat, ms), binds)) + NonEmptyList.one((ms, MatchString(arg, pat, ms), binds)) } case lp @ Pattern.ListPat(_) => lp.toPositionalStruct(empty, cons) match { @@ -817,7 +757,6 @@ object Matchless { (Pattern[(PackageName, Constructor), Type], Expr) ] ): F[Expr] = { - def recur( arg: CheapExpr, branches: NonEmptyList[ diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index f2a93e3ca..19da5db6b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -4,14 +4,13 @@ import cats.{Eval, Functor, Applicative} import cats.data.NonEmptyList import cats.evidence.Is import java.math.BigInteger +import org.bykn.bosatsu.pattern.StrPart import scala.collection.immutable.LongMap import scala.collection.mutable.{LongMap => MLongMap} import Identifier.Bindable import Value._ -import cats.implicits._ - object MatchlessToValue { import Matchless._ @@ -209,7 +208,7 @@ object MatchlessToValue { // we have nothing to bind loop(str).map { strV => val arg = strV.asExternal.toAny.asInstanceOf[String] - matchString(arg, pat, 0) != null + StrPart.matchString(arg, pat, 0) != null } case _ => val bary = binds.iterator.collect { case LocalAnonMut(id) => @@ -219,7 +218,7 @@ object MatchlessToValue { // this may be static val matchScope = loop(str).map { str => val arg = str.asExternal.toAny.asInstanceOf[String] - matchString(arg, pat, bary.length) + StrPart.matchString(arg, pat, bary.length) } // if we mutate scope, it has to be dynamic Dynamic { scope => @@ -547,123 +546,5 @@ object MatchlessToValue { } } - - private[this] val emptyStringArray: Array[String] = new Array[String](0) - def matchString( - str: String, - pat: List[StrPart], - binds: Int - ): Array[String] = { - import Matchless.StrPart._ - - val strLen = str.length() - val results = - if (binds > 0) new Array[String](binds) else emptyStringArray - - def loop(offset: Int, pat: List[StrPart], next: Int): Boolean = - pat match { - case Nil => offset == strLen - case LitStr(expect) :: tail => - val len = expect.length - str.regionMatches(offset, expect, 0, len) && loop( - offset + len, - tail, - next - ) - case (c: CharPart) :: tail => - try { - val nextOffset = str.offsetByCodePoints(offset, 1) - val n = - if (c.capture) { - results(next) = str.substring(offset, nextOffset) - next + 1 - } else next - - loop(nextOffset, tail, n) - } catch { - case _: IndexOutOfBoundsException => false - } - case (h: Glob) :: tail => - tail match { - case Nil => - // we capture all the rest - if (h.capture) { - results(next) = str.substring(offset) - } - true - case rest @ ((_: CharPart) :: _) => - val matchableSizes = MatchSize[List](rest) - - def canMatch(off: Int): Boolean = - matchableSizes.canMatch(str.codePointCount(off, strLen)) - - // (.*)(.)tail2 - // this is a naive algorithm that just - // checks at all possible later offsets - // a smarter algorithm could see if there - // are Lit parts that can match or not - var matched = false - var off1 = offset - val n1 = if (h.capture) (next + 1) else next - while (!matched && (off1 < strLen)) { - matched = canMatch(off1) && loop(off1, rest, n1) - if (!matched) { - off1 = off1 + Character.charCount(str.codePointAt(off1)) - } - } - - matched && { - if (h.capture) { - results(next) = str.substring(offset, off1) - } - true - } - case LitStr(expect) :: tail2 => - val next1 = if (h.capture) next + 1 else next - - val matchableSizes = MatchSize(tail2) - - def canMatch(off: Int): Boolean = - matchableSizes.canMatchUtf16Count(strLen - off) - - var start = offset - var result = false - while (start >= 0) { - val candidate = str.indexOf(expect, start) - if (candidate >= 0) { - // we have to skip the current expect string - val nextOff = candidate + expect.length - val check1 = - canMatch(nextOff) && loop(nextOff, tail2, next1) - if (check1) { - // this was a match, write into next if needed - if (h.capture) { - results(next) = str.substring(offset, candidate) - } - result = true - start = -1 - } else { - // we couldn't match here, try just after candidate - start = candidate + Character.charCount( - str.codePointAt(candidate) - ) - } - } else { - // no more candidates - start = -1 - } - } - result - // $COVERAGE-OFF$ - case (_: Glob) :: _ => - // this should be an error at compile time since it - // is never meaningful to have two adjacent globs - sys.error(s"invariant violation, adjacent globs: $pat") - // $COVERAGE-ON$ - } - } - - if (loop(0, pat, 0)) results else null - } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index 136245e45..eca2f87a2 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -5,6 +5,7 @@ import cats.data.NonEmptyList import cats.parse.{Parser0 => P0, Parser => P} import org.typelevel.paiges.{Doc, Document} import org.bykn.bosatsu.pattern.{NamedSeqPattern, SeqPattern, SeqPart} +import java.util.regex.{Pattern => RegexPattern} import Parser.{Combinators, maybeSpace, MaybeTupleOrParens} import cats.implicits._ @@ -464,10 +465,10 @@ object Pattern { } } - lazy val toNamedSeqPattern: NamedSeqPattern[Char] = + lazy val toNamedSeqPattern: NamedSeqPattern[Int] = StrPat.toNamedSeqPattern(this) - lazy val toSeqPattern: SeqPattern[Char] = toNamedSeqPattern.unname + lazy val toSeqPattern: SeqPattern[Int] = toNamedSeqPattern.unname lazy val toLiteralString: Option[String] = toSeqPattern.toLiteralSeq.map(_.mkString) @@ -476,6 +477,30 @@ object Pattern { def matches(str: String): Boolean = isTotal || matcher(str).isDefined + + /** + * Convert to a regular expression matching this pattern, which + * uses reluctant modifiers + */ + def toRegex: RegexPattern = { + def mapPart(p: StrPart): String = + p match { + case StrPart.NamedStr(_) => "(.*?)" + case StrPart.WildStr => ".*?" + case StrPart.NamedChar(_) => "(.)" + case StrPart.WildChar => "." + case StrPart.LitStr(s) => + // we need to escape any characters that may be in regex + RegexPattern.quote(s) + } + RegexPattern.compile( + parts + .iterator + .map(mapPart(_)) + .mkString, + RegexPattern.DOTALL + ) + } } /** Patterns like Some(_) as foo as binds tighter than |, so use ( ) with @@ -600,14 +625,19 @@ object Pattern { val Empty: StrPat = fromLitStr("") val Wild: StrPat = StrPat(NonEmptyList.one(StrPart.WildStr)) - def fromSeqPattern(sp: SeqPattern[Char]): StrPat = { - def lit(rev: List[Char]): List[StrPart.LitStr] = + def fromSeqPattern(sp: SeqPattern[Int]): StrPat = { + def lit(rev: List[Int]): List[StrPart.LitStr] = if (rev.isEmpty) Nil - else StrPart.LitStr(rev.reverse.mkString) :: Nil + else { + val cps = rev.reverse + val bldr = new java.lang.StringBuilder + cps.foreach(bldr.appendCodePoint(_)) + StrPart.LitStr(bldr.toString) :: Nil + } def loop( - ps: List[SeqPart[Char]], - front: List[Char] + ps: List[SeqPart[Int]], + front: List[Int] ): NonEmptyList[StrPart] = ps match { case Nil => NonEmptyList.fromList(lit(front)).getOrElse(Empty.parts) @@ -638,10 +668,10 @@ object Pattern { StrPat(loop(sp.toList, Nil)) } - def toNamedSeqPattern(sp: StrPat): NamedSeqPattern[Char] = { - val empty: NamedSeqPattern[Char] = NamedSeqPattern.NEmpty + def toNamedSeqPattern(sp: StrPat): NamedSeqPattern[Int] = { + val empty: NamedSeqPattern[Int] = NamedSeqPattern.NEmpty - def partToNsp(s: StrPart): NamedSeqPattern[Char] = + def partToNsp(s: StrPart): NamedSeqPattern[Int] = s match { case StrPart.NamedStr(n) => NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild) @@ -650,9 +680,9 @@ object Pattern { case StrPart.WildStr => NamedSeqPattern.Wild case StrPart.WildChar => NamedSeqPattern.Any case StrPart.LitStr(s) => - if (s.isEmpty) empty - else - s.toList.foldRight(empty) { (c, tail) => + StringUtil + .codePoints(s) + .foldRight(empty) { (c, tail) => NamedSeqPattern.NCat(NamedSeqPattern.fromLit(c), tail) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala index d8a14c060..aa56c3bc4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala +++ b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala @@ -211,6 +211,25 @@ object StringUtil extends GenericStringUtil { ('t', '\t'), ('v', 11.toChar) ) // vertical tab + + def codePoints(s: String): List[Int] = { + // .codePoints isn't available in scalajs + var idx = 0 + val bldr = List.newBuilder[Int] + while (idx < s.length) { + val cp = s.codePointAt(idx) + idx += Character.charCount(cp) + bldr += cp + } + + bldr.result() + } + + def fromCodePoints(it: Iterable[Int]): String = { + val bldr = new java.lang.StringBuilder + it.foreach(bldr.appendCodePoint(_)) + bldr.toString + } } object JsonStringUtil extends GenericStringUtil { diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index b8c070a26..9d56f2102 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -237,9 +237,9 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { ) private val strPatternSetOps: SetOps[StrPat] = - SetOps.imap[SeqPattern[Char], StrPat]( + SetOps.imap[SeqPattern[Int], StrPat]( SeqPattern.seqPatternSetOps( - SeqPart.part1SetOps(SetOps.distinct[Char]), + SeqPart.part1SetOps(SetOps.distinct[Int]), implicitly ), StrPat.fromSeqPattern(_), diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index dc1d54c4e..8b01a4801 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -3,6 +3,7 @@ package org.bykn.bosatsu import cats.Foldable import cats.data.NonEmptyList import org.bykn.bosatsu.rankn.{Type, TypeEnv} +import org.bykn.bosatsu.pattern.StrPart import Identifier.{Bindable, Constructor} @@ -826,25 +827,33 @@ object TypedExprNormalization { } } - case EvalResult.Constant(li @ Lit.Integer(i)) => + case EvalResult.Constant(li) => + def makeLet( p: Pattern[(PackageName, Constructor), Type] - ): Option[List[Bindable]] = + ): Option[List[(Bindable, Lit)]] = p match { case Pattern.Named(v, p) => - makeLet(p).map(v :: _) + makeLet(p).map((v, li) :: _) case Pattern.WildCard => Some(Nil) - case Pattern.Var(v) => Some(v :: Nil) + case Pattern.Var(v) => + Some((v, li) :: Nil) case Pattern.Annotation(p, _) => makeLet(p) - case Pattern.Literal(Lit.Integer(j)) => - if (j == i) Some(Nil) + case Pattern.Literal(litj) => + if (li == litj) Some(Nil) else None case Pattern.Union(h, t) => (h :: t).toList.iterator.map(makeLet).reduce(_.orElse(_)) - // $COVERAGE-OFF$ this is ill-typed so should be unreachable - case Pattern.PositionalStruct(_, _) | Pattern.ListPat(_) | - Pattern.StrPat(_) | - Pattern.Literal(Lit.Str(_) | Lit.Chr(_)) => + case sp @ Pattern.StrPat(_) => + li match { + case Lit.Str(str) => + StrPart.matchPattern(str, sp) + // $COVERAGE-OFF$ these are ill-typed so should be unreachable + case _ => None + } + + case Pattern.PositionalStruct(_, _) | Pattern.ListPat(_) => + // None // $COVERAGE-ON$ } @@ -852,17 +861,13 @@ object TypedExprNormalization { Foldable[NonEmptyList] .collectFirstSome[Branch[A], TypedExpr[A]](m.branches) { case (p, r) => - makeLet(p).map { names => - val lit = Literal[A](li, Type.getTypeOf(li), m.tag) - // all these names are bound to the lit - names.distinct.foldLeft(r) { case (r, n) => - Let(n, lit, r, RecursionKind.NonRecursive, m.tag) + makeLet(p).map { binds => + binds.foldRight(r) { case ((n, li), r) => + val te = Literal[A](li, Type.getTypeOf(li), m.arg.tag) + Let(n, te, r, RecursionKind.NonRecursive, m.tag) } } } - case EvalResult.Constant(Lit.Str(_) | Lit.Chr(_)) => - // TODO, we can match some of these statically - None } } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index a120457ac..1dd33a453 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -11,6 +11,7 @@ import org.bykn.bosatsu.{ Parser, } import org.bykn.bosatsu.codegen.Idents +import org.bykn.bosatsu.pattern.StrPart import org.bykn.bosatsu.rankn.Type import org.typelevel.paiges.Doc diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala index 46be1b555..e0a1f531b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala @@ -39,8 +39,8 @@ object Matcher { } } - val charMatcher: Matcher[Char, Char, Unit] = eqMatcher( - Eq.fromUniversalEquals[Char] + val intMatcher: Matcher[Int, Int, Unit] = eqMatcher( + Eq.fromUniversalEquals[Int] ) def fnMatch[A]: Matcher[A => Boolean, A, Unit] = diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala index 1f8c65b0a..4a70f36d1 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala @@ -746,6 +746,6 @@ object SeqPattern { } } - val stringUnitMatcher: Matcher[SeqPattern[Char], String, Unit] = + val stringUnitMatcher: Matcher[SeqPattern[Int], String, Unit] = matcher(Splitter.stringUnit) } diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala index 8ec55a594..b5cdf21b4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala @@ -2,6 +2,8 @@ package org.bykn.bosatsu.pattern import cats.Monoid +import org.bykn.bosatsu.StringUtil + import cats.implicits._ abstract class Splitter[-Elem, Item, Sequence, R] { @@ -25,37 +27,49 @@ abstract class Splitter[-Elem, Item, Sequence, R] { object Splitter { def stringSplitter[R]( - fn: Char => R - )(implicit m: Monoid[R]): Splitter[Char, Char, String, R] = - new Splitter[Char, Char, String, R] { + fn: Int => R + )(implicit m: Monoid[R]): Splitter[Int, Int, String, R] = + new Splitter[Int, Int, String, R] { val matcher = - Matcher.charMatcher + Matcher.intMatcher .mapWithInput((s, _) => fn(s)) val monoidResult = m - def positions(c: Char): String => LazyList[(String, Char, R, String)] = { - str => - def loop(init: Int): LazyList[(String, Char, R, String)] = - if (init >= str.length) LazyList.empty - else if (str.charAt(init) == c) { - ( - str.substring(0, init), - c, - fn(c), - str.substring(init + 1) - ) #:: loop(init + 1) - } else loop(init + 1) + def positions(c: Int): String => LazyList[(String, Int, R, String)] = { + str => { + def loop(strOffset: Int): LazyList[(String, Int, R, String)] = + if (strOffset >= str.length) LazyList.empty + else { + // we have to skip the entire codepoint + val cp = str.codePointAt(strOffset) + val csize = Character.charCount(cp) + if (cp == c) { + // this is a valid match + ( + str.substring(0, strOffset), + c, + fn(c), + str.substring(strOffset + csize) + ) #:: loop(strOffset + csize) + } else { + loop(strOffset + csize) + } + } loop(0) + } } - def anySplits(str: String): LazyList[(String, Char, R, String)] = - (0 until str.length) + def anySplits(str: String): LazyList[(String, Int, R, String)] = + (0 until str.codePointCount(0, str.length)) .to(LazyList) .map { idx => - val prefix = str.substring(0, idx) - val post = str.substring(idx + 1) - val c = str.charAt(idx) + // we have to skip to valid offsets + val offset = str.offsetByCodePoints(0, idx) + val prefix = str.substring(0, offset) + val c = str.codePointAt(offset) + val csize = Character.charCount(c) + val post = str.substring(offset + csize) (prefix, c, fn(c), post) } @@ -63,17 +77,27 @@ object Splitter { def uncons(s: String) = if (s.isEmpty) None - else Some((s.head, s.tail)) + else { + val c = s.codePointAt(0) + val csize = Character.charCount(c) + Some((c, s.substring(csize))) + } + + def toStr(cp: Int): String = + (new java.lang.StringBuilder).appendCodePoint(cp).toString - def cons(c: Char, s: String) = s"$c$s" + def cons(c: Int, s: String) = s"${toStr(c)}$s" def emptySeq = "" def catSeqs(s: List[String]) = s.mkString - override def toList(s: String) = s.toList - override def fromList(cs: List[Char]) = cs.mkString + override def toList(s: String) = + StringUtil.codePoints(s) + + override def fromList(cs: List[Int]) = + cs.iterator.map(toStr(_)).mkString } - val stringUnit: Splitter[Char, Char, String, Unit] = + val stringUnit: Splitter[Int, Int, String, Unit] = stringSplitter(_ => ()) abstract class ListSplitter[P, V, R] extends Splitter[P, V, List[V], R] { diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/StrPart.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/StrPart.scala new file mode 100644 index 000000000..f319269f6 --- /dev/null +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/StrPart.scala @@ -0,0 +1,221 @@ +package org.bykn.bosatsu.pattern + +import cats.Monoid +import org.bykn.bosatsu.{Pattern, Lit, Identifier} + +sealed abstract class StrPart +object StrPart { + sealed abstract class Glob(val capture: Boolean) extends StrPart + sealed abstract class CharPart(val capture: Boolean) extends StrPart + case object WildStr extends Glob(false) + case object IndexStr extends Glob(true) + case object WildChar extends CharPart(false) + case object IndexChar extends CharPart(true) + case class LitStr(asString: String) extends StrPart + + sealed abstract class MatchSize(val isExact: Boolean) { + def charCount: Int + def canMatch(cp: Int): Boolean + // we know chars/2 <= cpCount <= chars for utf16 + def canMatchUtf16Count(chars: Int): Boolean + } + object MatchSize { + case class Exactly(charCount: Int) extends MatchSize(true) { + def canMatch(cp: Int): Boolean = cp == charCount + def canMatchUtf16Count(chars: Int): Boolean = { + val cpmin = chars / 2 + val cpmax = chars + (cpmin <= charCount) && (charCount <= cpmax) + } + } + case class AtLeast(charCount: Int) extends MatchSize(false) { + def canMatch(cp: Int): Boolean = charCount <= cp + def canMatchUtf16Count(chars: Int): Boolean = { + val cpmax = chars + // we have any cp in [cpmin, cpmax] + // but we require charCount <= cp + (charCount <= cpmax) + } + } + + private val atLeast0 = AtLeast(0) + private val exactly0 = Exactly(0) + private val exactly1 = Exactly(1) + + def from(sp: StrPart): MatchSize = + sp match { + case _: Glob => atLeast0 + case _: CharPart => exactly1 + case LitStr(str) => + Exactly(str.codePointCount(0, str.length)) + } + + def apply[F[_]: cats.Foldable](f: F[StrPart]): MatchSize = + cats.Foldable[F].foldMap(f)(from) + + implicit val monoidMatchSize: Monoid[MatchSize] = + new Monoid[MatchSize] { + def empty: MatchSize = exactly0 + def combine(l: MatchSize, r: MatchSize) = + if (l.isExact && r.isExact) Exactly(l.charCount + r.charCount) + else AtLeast(l.charCount + r.charCount) + } + } + + private[this] val emptyStringArray: Array[String] = new Array[String](0) + /** + * This performs the matchstring algorithm on a literal string + * it returns null if there is no match, or the array of binds + * in order if there is a match + */ + def matchString( + str: String, + pat: List[StrPart], + binds: Int + ): Array[String] = { + val strLen = str.length() + val results = + if (binds > 0) new Array[String](binds) else emptyStringArray + + def loop(offset: Int, pat: List[StrPart], next: Int): Boolean = + pat match { + case Nil => offset == strLen + case LitStr(expect) :: tail => + val len = expect.length + str.regionMatches(offset, expect, 0, len) && loop( + offset + len, + tail, + next + ) + case (c: CharPart) :: tail => + try { + val nextOffset = str.offsetByCodePoints(offset, 1) + val n = + if (c.capture) { + results(next) = str.substring(offset, nextOffset) + next + 1 + } else next + + loop(nextOffset, tail, n) + } catch { + case _: IndexOutOfBoundsException => false + } + case (h: Glob) :: tail => + tail match { + case Nil => + // we capture all the rest + if (h.capture) { + results(next) = str.substring(offset) + } + true + case rest @ ((_: CharPart) :: _) => + val matchableSizes = MatchSize[List](rest) + + def canMatch(off: Int): Boolean = + matchableSizes.canMatch(str.codePointCount(off, strLen)) + + // (.*)(.)tail2 + // this is a naive algorithm that just + // checks at all possible later offsets + // a smarter algorithm could see if there + // are Lit parts that can match or not + var matched = false + var off1 = offset + val n1 = if (h.capture) (next + 1) else next + while (!matched && (off1 < strLen)) { + matched = canMatch(off1) && loop(off1, rest, n1) + if (!matched) { + off1 = off1 + Character.charCount(str.codePointAt(off1)) + } + } + + matched && { + if (h.capture) { + results(next) = str.substring(offset, off1) + } + true + } + case LitStr(expect) :: tail2 => + val next1 = if (h.capture) next + 1 else next + + val matchableSizes = MatchSize(tail2) + + def canMatch(off: Int): Boolean = + matchableSizes.canMatchUtf16Count(strLen - off) + + var start = offset + var result = false + while (start >= 0) { + val candidate = str.indexOf(expect, start) + if (candidate >= 0) { + // we have to skip the current expect string + val nextOff = candidate + expect.length + val check1 = + canMatch(nextOff) && loop(nextOff, tail2, next1) + if (check1) { + // this was a match, write into next if needed + if (h.capture) { + results(next) = str.substring(offset, candidate) + } + result = true + start = -1 + } else { + // we couldn't match here, try just after candidate + start = candidate + Character.charCount( + str.codePointAt(candidate) + ) + } + } else { + // no more candidates + start = -1 + } + } + result + // $COVERAGE-OFF$ + case (_: Glob) :: _ => + // this should be an error at compile time since it + // is never meaningful to have two adjacent globs + sys.error(s"invariant violation, adjacent globs: $pat") + // $COVERAGE-ON$ + } + } + + if (loop(0, pat, 0)) results else null + } + + def matchPattern(str: String, pattern: Pattern.StrPat): Option[List[(Identifier.Bindable, Lit.StringMatchResult)]] = { + val partList = pattern.parts.toList + + val sbinds: List[String => (Identifier.Bindable, Lit.StringMatchResult)] = + partList + .collect { + // that each name is distinct + // should be checked in the SourceConverter/TotalityChecking code + case Pattern.StrPart.NamedStr(n) => + { (value: String) => (n, Lit.Str(value)) } + case Pattern.StrPart.NamedChar(n) => + { (value: String) => (n, Lit.Chr(value)) } + } + + val pat = partList.map { + case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr + case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar + case Pattern.StrPart.WildStr => StrPart.WildStr + case Pattern.StrPart.WildChar => StrPart.WildChar + case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) + } + + val result = StrPart.matchString(str, pat, sbinds.length) + if (result == null) None + else { + // we match: + val matched = result + .iterator + .zip(sbinds.iterator) + .map { case (m, fn) => fn(m) } + .toList + + Some(matched) + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index 9920941df..ad88e3530 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -18,6 +18,25 @@ object Generators { val num: Gen[Char] = Gen.oneOf('0' to '9') val identC: Gen[Char] = Gen.frequency((10, lower), (1, upper), (1, num)) + val genCodePoints: Gen[Int] = + Gen.frequency( + (10, Gen.choose(0, 0xd7ff)), + ( + 1, + Gen.choose(0, 0x10ffff).filterNot { cp => + (0xd800 <= cp && cp <= 0xdfff) + } + ) + ) + + val genValidUtf: Gen[String] = + Gen.listOf(genCodePoints) + .map { points => + val bldr = new java.lang.StringBuilder + points.foreach(bldr.appendCodePoint(_)) + bldr.toString + } + val whiteSpace: Gen[String] = Gen.listOf(Gen.oneOf(' ', '\t', '\n')).map(_.mkString) @@ -561,15 +580,22 @@ object Generators { ): NonEmptyList[Pattern.StrPart] = nel match { case NonEmptyList(_, Nil) => nel - case NonEmptyList(h1, h2 :: t) if isWild(h1) && isWild(h2) => - makeValid(NonEmptyList(h2, t)) + case NonEmptyList(h1, h2 :: t) + if Pattern.StrPat(NonEmptyList.one(h1)).names.exists(Pattern.StrPat(NonEmptyList(h2, t)).names.toSet) => + makeValid(NonEmptyList(h2, t)) case NonEmptyList( Pattern.StrPart.LitStr(h1), Pattern.StrPart.LitStr(h2) :: t ) => makeValid(NonEmptyList(Pattern.StrPart.LitStr(h1 + h2), t)) case NonEmptyList(h1, h2 :: t) => - NonEmptyList(h1, makeValid(NonEmptyList(h2, t)).toList) + val tail = makeValid(NonEmptyList(h2, t)) + if (isWild(tail.head) && isWild(h1)) { + tail + } + else { + h1 :: tail + } } for { diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index ffa5d05e9..d1cb4e734 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -14,7 +14,7 @@ import cats.implicits._ import cats.parse.{Parser0 => P0, Parser => P} import Parser.{optionParse, unsafeParse, Indy} -import Generators.{shrinkDecl, shrinkStmt} +import Generators.{shrinkDecl, shrinkStmt, genCodePoints} import org.scalatest.funsuite.AnyFunSuite trait ParseFns { @@ -271,24 +271,6 @@ class ParserTest extends ParserTestBase { StringUtil.utf16Codepoint.repAs(StringUtil.codePointAccumulator) | P.pure( "" ) - val genCodePoints: Gen[Int] = - Gen.frequency( - (10, Gen.choose(0, 0xd7ff)), - ( - 1, - Gen.choose(0, 0x10ffff).filterNot { cp => - (0xd800 <= cp && cp <= 0xdfff) - } - ) - ) - - // .codePoints isn't available in scalajs - def jsCompatCodepoints(s: String): List[Int] = - if (s.isEmpty) Nil - else - (s.codePointAt(0) :: jsCompatCodepoints( - s.substring(s.offsetByCodePoints(0, 1)) - )) forAll(Gen.listOf(genCodePoints)) { cps => val strbuilder = new java.lang.StringBuilder @@ -300,8 +282,8 @@ class ParserTest extends ParserTestBase { assert(parsed == Right(str)) assert( - parsed.map(jsCompatCodepoints) == Right(cps), - s"hex = $hex, str = ${jsCompatCodepoints(str)} utf16 = ${str.toCharArray().toList.map(_.toInt.toHexString)}" + parsed.map(StringUtil.codePoints) == Right(cps), + s"hex = $hex, str = ${StringUtil.codePoints(str)} utf16 = ${str.toCharArray().toList.map(_.toInt.toHexString)}" ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala index 45aa0562d..66651ca3d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala @@ -1,6 +1,7 @@ package org.bykn.bosatsu.pattern import org.bykn.bosatsu.set.{Rel, SetOps} +import org.bykn.bosatsu.StringUtil import org.scalacheck.{Arbitrary, Gen, Shrink} import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, @@ -15,9 +16,11 @@ import org.scalatest.funsuite.AnyFunSuite object StringSeqPatternGen { - type Named = NamedSeqPattern[Char] + type Named = NamedSeqPattern[Int] val Named = NamedSeqPattern + def lit(c: Char): SeqPart[Int] = Lit(c.toInt) + // generate a string of 0s and 1s to make matches more likely val genBitString: Gen[String] = for { @@ -26,17 +29,17 @@ object StringSeqPatternGen { list <- Gen.listOfN(sz, g) } yield list.mkString - val genPart: Gen[SeqPart[Char]] = { + val genPart: Gen[SeqPart[Int]] = { import SeqPart._ Gen.frequency( - (15, Gen.oneOf(Lit('0'), Lit('1'))), + (15, Gen.oneOf(lit('0'), lit('1'))), (2, Gen.const(AnyElem)), (1, Gen.const(Wildcard)) ) } - val genPat: Gen[SeqPattern[Char]] = { + val genPat: Gen[SeqPattern[Int]] = { val cat = for { h <- genPart @@ -46,9 +49,9 @@ object StringSeqPatternGen { Gen.frequency((1, Gen.const(Empty)), (5, cat)) } - implicit val arbPattern: Arbitrary[SeqPattern[Char]] = Arbitrary(genPat) + implicit val arbPattern: Arbitrary[SeqPattern[Int]] = Arbitrary(genPat) - implicit lazy val shrinkPat: Shrink[SeqPattern[Char]] = + implicit lazy val shrinkPat: Shrink[SeqPattern[Int]] = Shrink { case Empty => Empty #:: Stream.empty case Cat(Wildcard, t) => @@ -89,10 +92,10 @@ object StringSeqPatternGen { res } - val genNamed: Gen[NamedSeqPattern[Char]] = + val genNamed: Gen[NamedSeqPattern[Int]] = genNamedFn(genPart, 0).map(_._2) - implicit val arbNamed: Arbitrary[NamedSeqPattern[Char]] = Arbitrary(genNamed) + implicit val arbNamed: Arbitrary[NamedSeqPattern[Int]] = Arbitrary(genNamed) def interleave[A](s1: Stream[A], s2: Stream[A]): Stream[A] = if (s1.isEmpty) s2 @@ -101,7 +104,7 @@ object StringSeqPatternGen { s1.head #:: interleave(s2, s1.tail) } - implicit val shrinkNamedSeqPattern: Shrink[NamedSeqPattern[Char]] = + implicit val shrinkNamedSeqPattern: Shrink[NamedSeqPattern[Int]] = Shrink { case NamedSeqPattern.NEmpty => Stream.Empty case NamedSeqPattern.Bind(n, p) => @@ -109,8 +112,8 @@ object StringSeqPatternGen { val binded = sp.map(NamedSeqPattern.Bind(n, _)) interleave(sp, binded) case NamedSeqPattern.NSeqPart(p) => - def tail: Stream[SeqPart[Char]] = - Lit('0') #:: Lit('1') #:: Stream.Empty + def tail: Stream[SeqPart[Int]] = + lit('0') #:: lit('1') #:: Stream.Empty val sp = p match { case Wildcard => AnyElem #:: tail @@ -744,14 +747,14 @@ class BoolSeqPatternTest } } -class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { +class SeqPatternTest extends SeqPatternLaws[Int, Int, String, Unit] { import SeqPart._ def genPattern = StringSeqPatternGen.genPat def genNamed = StringSeqPatternGen.genNamed def genSeq = StringSeqPatternGen.genBitString - def splitter: Splitter[Char, Char, String, Unit] = + def splitter: Splitter[Int, Int, String, Unit] = Splitter.stringSplitter(_ => ()) val pmatcher = Pattern.matcher(splitter) @@ -763,23 +766,23 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { def namedMatch(p: Named, s: String): Option[Map[String, String]] = NamedSeqPattern.matcher(Splitter.stringSplitter(_.toString))(p)(s).map(_._2) - implicit val setOpsChar: SetOps[Char] = SetOps.distinct[Char] - def setOps: SetOps[Pattern] = Pattern.seqPatternSetOps[Char] + implicit val setOpsChar: SetOps[Int] = SetOps.distinct[Int] + def setOps: SetOps[Pattern] = Pattern.seqPatternSetOps[Int] import StringSeqPatternGen._ def toPattern(s: String): Pattern = - s.toList.foldRight(Pattern.Empty: Pattern) { (c, r) => + StringUtil.codePoints(s).foldRight(Pattern.Empty: Pattern) { (c, r) => Pattern.Cat(Lit(c), r) } override def diffUBRegressions - : List[(SeqPattern[Char], SeqPattern[Char], String)] = + : List[(SeqPattern[Int], SeqPattern[Int], String)] = List({ val p1 = Cat(AnyElem, Cat(Wildcard, Empty)) val p2 = Cat( Wildcard, - Cat(AnyElem, Cat(Lit('0'), Cat(Lit('1'), Cat(Wildcard, Empty)))) + Cat(AnyElem, Cat(lit('0'), Cat(lit('1'), Cat(Wildcard, Empty)))) ) (p1, p2, "11") }) @@ -875,7 +878,7 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { test("Named.matches + render agree") { def law(n: Named, str: String) = namedMatch(n, str).foreach { m => - n.render(m)(_.toString) match { + n.render(m) { codePoint => StringUtil.fromCodePoints(codePoint :: Nil) } match { case Some(s0) => assert(s0 == str, s"m = $m") case None => // this can only happen if we have unnamed Wild/AnyElem @@ -905,9 +908,9 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { import Named._ val regressions: List[(Named, String)] = - (NCat(NEmpty, NCat(NSeqPart(Lit('1')), NSeqPart(Wildcard))), "1") :: - (NCat(NSeqPart(Lit('1')), NSeqPart(Wildcard)), "1") :: - (NSeqPart(Lit('1')), "1") :: + (NCat(NEmpty, NCat(NSeqPart(lit('1')), NSeqPart(Wildcard))), "1") :: + (NCat(NSeqPart(lit('1')), NSeqPart(Wildcard)), "1") :: + (NSeqPart(lit('1')), "1") :: Nil regressions.foreach { case (n, s) => namedMatchesPatternLaw(n, s) } @@ -926,14 +929,14 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { test("regression subset example") { val p1 = Cat(AnyElem, Cat(Wildcard, Empty)) - val p2 = Cat(Wildcard, Cat(Lit('0'), Cat(Lit('1'), Empty))) + val p2 = Cat(Wildcard, Cat(lit('0'), Cat(lit('1'), Empty))) assert(setOps.subset(p2, p1)) } test("intersection regression") { - val p1 = Cat(Wildcard, Cat(Lit('0'), Cat(Lit('1'), Empty))) - val p2 = Cat(Lit('0'), Cat(Lit('0'), Cat(Lit('0'), Cat(Wildcard, Empty)))) + val p1 = Cat(Wildcard, Cat(lit('0'), Cat(lit('1'), Empty))) + val p2 = Cat(lit('0'), Cat(lit('0'), Cat(lit('0'), Cat(Wildcard, Empty)))) assert(setOps.relate(p2, p1) == Rel.Intersects) assert(setOps.intersection(p1, p2).nonEmpty) diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/StrPartTest.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/StrPartTest.scala new file mode 100644 index 000000000..72e63d1d8 --- /dev/null +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/StrPartTest.scala @@ -0,0 +1,103 @@ +package org.bykn.bosatsu.pattern + +import org.bykn.bosatsu.Generators.{genStrPat, genValidUtf} +import org.bykn.bosatsu.StringUtil +import org.scalacheck.Prop.forAll +import org.scalacheck.Gen + +class StrPartTest extends munit.ScalaCheckSuite { + override def scalaCheckTestParameters = + super.scalaCheckTestParameters + .withMinSuccessfulTests(100000) + .withMaxDiscardRatio(10) + + val nonUnicode: Gen[String] = + Gen.oneOf( + Gen.asciiStr, + Gen.identifier, + ) + + val genStr: Gen[String] = + Gen.oneOf( + Gen.asciiStr, + Gen.identifier, + genValidUtf + ) + + property("StringUtil codePoints work (used internally by matching)") { + forAll(genStr) { str => + val cp = StringUtil.codePoints(str) + val s1 = StringUtil.fromCodePoints(cp) + assertEquals(s1, str, s"codepoints = $cp") + } + } + + property("pat.matcher works for non-unicode strings") { + forAll(nonUnicode, genStrPat) { (str, pat) => + val sp = pat.toSeqPattern + StrPart.matchPattern(str, pat) match { + case Some(binds) => + val justNames = binds.map(_._1) + assertEquals(justNames.distinct, justNames) + assertEquals(pat.names, justNames) + // this should agree with the matches method + assert(pat.matcher(str).isDefined, s"seqPattern = $sp, named = ${pat.toNamedSeqPattern}") + case None => + assert(pat.matcher(str).isEmpty, s"seqPattern = $sp, named = ${pat.toNamedSeqPattern}") + } + } + } + + property("matches finds all the bindings in order (unicode)") { + forAll(genStr, genStrPat) { (str, pat) => + StrPart.matchPattern(str, pat) match { + case Some(binds) => + val justNames = binds.map(_._1) + assertEquals(justNames.distinct, justNames) + assertEquals(pat.names, justNames) + // this should agree with the matches method + assert(pat.matches(str)) + case None => + assert(!pat.matches(str)) + } + } + } + + property("matchPattern agrees with NamedPattern.matcher (unicode)") { + val nm = NamedSeqPattern.matcher(Splitter.stringUnit) + forAll(genStr, genStrPat) { (str, pat) => + val namedMatcher = nm(pat.toNamedSeqPattern) + + val res1 = StrPart.matchPattern(str, pat).map { pairs => + pairs.map { case (b, sr) => + (b.sourceCodeRepr, sr.asStr) + } + .toMap + } + val matchRes = namedMatcher(str).map(_._2) + assertEquals(matchRes, res1) + } + } + + property("matches agrees with toRegex") { + forAll(genStr, genStrPat) { (str, pat) => + val re = pat.toRegex + val matcher = re.matcher(str) + + StrPart.matchPattern(str, pat) match { + case Some(binds) => + assert(matcher.matches(), s"binds = $binds, re = $re") + val reMatches = (1 to matcher.groupCount) + .map { idx => + matcher.group(idx) + } + .toList + + // TODO: this fails + assertEquals(pat.names.zip(reMatches), binds.map { case (k, v) => (k, v.asStr)}) + case None => + assert(!matcher.matches()) + } + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala index 32dcc759e..5d6dd2183 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala @@ -6,12 +6,14 @@ import cats.Eq import org.scalacheck.Gen import SeqPattern.{Cat, Empty} -import SeqPart.{AnyElem, Lit, Wildcard} +import SeqPart.{AnyElem, Wildcard} import cats.implicits._ -class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { - type Pattern = SeqPattern[Char] +import StringSeqPatternGen.lit + +class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Int]] { + type Pattern = SeqPattern[Int] val Pattern = SeqPattern override def scalaCheckTestParameters = @@ -66,21 +68,20 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { } } - implicit val setOpsChar: SetOps[Char] = SetOps.distinct[Char] - val setOps: SetOps[Pattern] = Pattern.seqPatternSetOps[Char] + implicit val setOpsChar: SetOps[Int] = SetOps.distinct[Int] + val setOps: SetOps[Pattern] = Pattern.seqPatternSetOps[Int] test("subsetConsistencyLaw regressions") { import SeqPattern.{Cat, Empty} - import SeqPart.Lit - val regressions: List[(SeqPattern[Char], SeqPattern[Char])] = + val regressions: List[(SeqPattern[Int], SeqPattern[Int])] = ( - Cat(Lit('1'), Cat(Lit('1'), Cat(Lit('1'), Cat(Lit('1'), Empty)))), - Cat(Lit('0'), Cat(Lit('1'), Cat(Lit('1'), Empty))) + Cat(lit('1'), Cat(lit('1'), Cat(lit('1'), Cat(lit('1'), Empty)))), + Cat(lit('0'), Cat(lit('1'), Cat(lit('1'), Empty))) ) :: ( - Cat(Lit('1'), Cat(Lit('0'), Cat(Lit('1'), Cat(Lit('0'), Empty)))), - Cat(Lit('0'), Cat(Lit('1'), Empty)) + Cat(lit('1'), Cat(lit('0'), Cat(lit('1'), Cat(lit('0'), Empty)))), + Cat(lit('0'), Cat(lit('1'), Empty)) ) :: Nil @@ -91,14 +92,14 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { test("*x* problems") { import SeqPattern.{Cat, Empty} - import SeqPart.{Lit, Wildcard} + import SeqPart.Wildcard val x = Cat( Wildcard, - Cat(Lit('q'), Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty)))) + Cat(lit('q'), Cat(Wildcard, Cat(lit('p'), Cat(Wildcard, Empty)))) ) - val y = Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty))) - val z = Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))) + val y = Cat(Wildcard, Cat(lit('p'), Cat(Wildcard, Empty))) + val z = Cat(Wildcard, Cat(lit('q'), Cat(Wildcard, Empty))) // note y and z are clearly bigger than x because they are prefix/suffix that end/start with // Wildcard assert(setOps.difference(x, y).isEmpty) @@ -107,29 +108,29 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { test("(a - b) n c = (a n c) - (b n c) regressions") { val regressions - : List[(SeqPattern[Char], SeqPattern[Char], SeqPattern[Char])] = + : List[(SeqPattern[Int], SeqPattern[Int], SeqPattern[Int])] = ( Cat(Wildcard, Empty), - Cat(AnyElem, Cat(Lit('1'), Cat(AnyElem, Empty))), - Cat(AnyElem, Cat(Lit('1'), Cat(Lit('0'), Empty))) + Cat(AnyElem, Cat(lit('1'), Cat(AnyElem, Empty))), + Cat(AnyElem, Cat(lit('1'), Cat(lit('0'), Empty))) ) :: ( - Cat(Wildcard, Cat(Lit('0'), Empty)), - Cat(AnyElem, Cat(Lit('1'), Cat(AnyElem, Cat(Lit('0'), Empty)))), - Cat(AnyElem, Cat(Lit('1'), Cat(Lit('0'), Cat(Lit('0'), Empty)))) + Cat(Wildcard, Cat(lit('0'), Empty)), + Cat(AnyElem, Cat(lit('1'), Cat(AnyElem, Cat(lit('0'), Empty)))), + Cat(AnyElem, Cat(lit('1'), Cat(lit('0'), Cat(lit('0'), Empty)))) ) :: ( - Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))), + Cat(Wildcard, Cat(lit('q'), Cat(Wildcard, Empty))), Cat(Wildcard, Empty), - Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty))) + Cat(Wildcard, Cat(lit('p'), Cat(Wildcard, Empty))) ) :: /* * This fails currently * see: https://github.com/johnynek/bosatsu/issues/486 { - val p1 = Cat(Wildcard,Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('0'),Empty)))) - val p2 = Cat(AnyElem,Cat(Lit('1'),Cat(Wildcard,Cat(Lit('0'),Empty)))) - val p3 = Cat(Lit('1'),Cat(Lit('1'),Cat(Wildcard,Cat(Lit('0'),Empty)))) + val p1 = Cat(Wildcard,Cat(lit('1'),Cat(lit('0'),Cat(lit('0'),Empty)))) + val p2 = Cat(AnyElem,Cat(lit('1'),Cat(Wildcard,Cat(lit('0'),Empty)))) + val p3 = Cat(lit('1'),Cat(lit('1'),Cat(Wildcard,Cat(lit('0'),Empty)))) (p1, p2, p3) } :: */ @@ -139,8 +140,8 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { } test("intersection regression") { - val p1 = Cat(Wildcard, Cat(Lit('0'), Cat(Lit('1'), Empty))) - val p2 = Cat(Lit('0'), Cat(Lit('0'), Cat(Lit('0'), Cat(Wildcard, Empty)))) + val p1 = Cat(Wildcard, Cat(lit('0'), Cat(lit('1'), Empty))) + val p2 = Cat(lit('0'), Cat(lit('0'), Cat(lit('0'), Cat(Wildcard, Empty)))) assert(setOps.relate(p1, p2) == Rel.Intersects) assert(setOps.relate(p2, p1) == Rel.Intersects) From 52b8ec62a6ab2b60ed933c8f34337806f6e0b955 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 20 Nov 2024 07:41:44 -1000 Subject: [PATCH 08/11] Implement searchList in C (#1264) * Implement searchList in C * minor formatting --- .../bykn/bosatsu/codegen/clang/ClangGen.scala | 111 +++++++++++++++++- .../org/bykn/bosatsu/codegen/clang/Code.scala | 12 +- .../bosatsu/codegen/python/PythonGen.scala | 4 +- .../bosatsu/codegen/clang/ClangGenTest.scala | 3 +- 4 files changed, 119 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 8251e8989..6b4064b1b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -200,6 +200,7 @@ object ClangGen { def directFn(p: PackageName, b: Bindable): T[Option[Code.Ident]] def directFn(b: Bindable): T[Option[(Code.Ident, Boolean)]] def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A] + def currentTop: T[Option[(PackageName, Bindable)]] def staticValueName(p: PackageName, b: Bindable): T[Code.Ident] def constructorFn(p: PackageName, b: Bindable): T[Code.Ident] @@ -318,10 +319,11 @@ object ClangGen { // this is just get_variant(expr) == expect vl.onExpr { expr => pv(Code.Ident("get_variant")(expr) =:= Code.IntLiteral(expect)) }(newLocalName) } - case sl @ SearchList(lst, init, check, leftAcc) => - // TODO: ??? - println(s"TODO: implement boolToValue($sl) returning false") - pv(Code.FalseLit) + case SearchList(lst, init, check, leftAcc) => + (boolToValue(check), innerToValue(init)) + .flatMapN { (condV, initV) => + searchList(lst, initV, condV, leftAcc) + } case ms @ MatchString(arg, parts, binds) => // TODO: ??? println(s"TODO: implement boolToValue($ms) returning false") @@ -334,6 +336,100 @@ object ClangGen { case TrueConst => pv(Code.TrueLit) } + def searchList( + locMut: LocalAnonMut, + initVL: Code.ValueLike, + checkVL: Code.ValueLike, + optLeft: Option[LocalAnonMut] + ): T[Code.ValueLike] = { + import Code.Expression + + val emptyList: Expression = + Code.Ident("alloc_enum0")(Code.IntLiteral(0)) + + def isNonEmptyList(expr: Expression): Expression = + Code.Ident("get_variant")(expr) =:= Code.IntLiteral(1) + + def headList(expr: Expression): Expression = + Code.Ident("get_enum_index")(expr, Code.IntLiteral(0)) + + def tailList(expr: Expression): Expression = + Code.Ident("get_enum_index")(expr, Code.IntLiteral(1)) + + def consList(head: Expression, tail: Expression): Expression = + Code.Ident("alloc_enum2")(Code.IntLiteral(1), head, tail) + /* + * here is the implementation from MatchlessToValue + * + Dynamic { (scope: Scope) => + var res = false + var currentList = initF(scope) + var leftList = VList.VNil + while (currentList ne null) { + currentList match { + case nonempty@VList.Cons(head, tail) => + scope.updateMut(mutV, nonempty) + scope.updateMut(left, leftList) + res = checkF(scope) + if (res) { currentList = null } + else { + currentList = tail + leftList = VList.Cons(head, leftList) + } + case _ => + currentList = null + // we don't match empty lists + } + } + res + } + */ + for { + currentList <- getAnon(locMut.ident) + optLeft <- optLeft.traverse(lm => getAnon(lm.ident)) + res <- newLocalName("result") + tmpList <- newLocalName("tmp_list") + declTmpList <- Code.ValueLike.declareVar(Code.TypeIdent.BValue, tmpList, initVL)(newLocalName) + /* + top <- currentTop + _ = println(s"""in $top: searchList( + $locMut: LocalAnonMut, + $initVL: Code.ValueLike, + $checkVL: Code.ValueLike, + $optLeft: Option[LocalAnonMut] + )""") + */ + } yield + (Code + .Statements( + Code.DeclareVar(Nil, Code.TypeIdent.Bool, res, Some(Code.FalseLit)), + declTmpList + ) + .maybeCombine( + optLeft.map(_ := emptyList), + ) + + // we don't match empty lists, so if currentList reaches Empty we are done + Code.While( + isNonEmptyList(tmpList), + Code.block( + currentList := tmpList, + res := checkVL, + Code.ifThenElse(res, + { tmpList := emptyList }, + { + (tmpList := tailList(tmpList)) + .maybeCombine( + optLeft.map { left => + left := consList(headList(currentList), left) + } + ) + } + ) + ) + ) + ) :+ res + } + // We have to lift functions to the top level and not // create any nesting def innerFn(fn: FnExpr): T[Code.ValueLike] = @@ -509,7 +605,7 @@ object ClangGen { for { name <- getBinding(arg) result <- innerToValue(in) - stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) + stmt <- Code.ValueLike.declareVar(Code.TypeIdent.BValue, name, v)(newLocalName) } yield stmt +: result } } @@ -521,7 +617,7 @@ object ClangGen { for { name <- getAnon(idx) result <- innerToValue(in) - stmt <- Code.ValueLike.declareVar(name, Code.TypeIdent.BValue, v)(newLocalName) + stmt <- Code.ValueLike.declareVar(Code.TypeIdent.BValue, name, v)(newLocalName) } yield stmt +: result } } @@ -941,6 +1037,9 @@ object ClangGen { _ <- StateT { (s: State) => result(s.copy(currentTop = None), ()) } } yield a + val currentTop: T[Option[(PackageName, Bindable)]] = + StateT { (s: State) => result(s, s.currentTop) } + def staticValueName(p: PackageName, b: Bindable): T[Code.Ident] = monadImpl.pure(Code.Ident(Idents.escape("___bsts_s_", fullName(p, b)))) def constructorFn(p: PackageName, b: Bindable): T[Code.Ident] = diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala index cb0d1d3f0..ab19634e3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala @@ -192,8 +192,8 @@ object Code { } def declareVar[F[_]: Monad]( - ident: Ident, tpe: TypeIdent, + ident: Ident, value: ValueLike)(newLocalName: String => F[Code.Ident]): F[Statement] = value.exprToStatement[F] { expr => Monad[F].pure(DeclareVar(Nil, tpe, ident, Some(expr))) @@ -329,8 +329,15 @@ object Code { sealed trait Statement extends Code { def +(stmt: Statement): Statement = Statements.combine(this, stmt) + def maybeCombine(that: Option[Statement]): Statement = + that match { + case Some(t) => Statements.combine(this, t) + case None => this + + } def :+(vl: ValueLike): ValueLike = (this +: vl) } + case class Assignment(target: Expression, value: Expression) extends Statement case class DeclareArray(tpe: TypeIdent, ident: Ident, values: Either[Int, List[Expression]]) extends Statement case class DeclareVar(attrs: List[Attr], tpe: TypeIdent, ident: Ident, value: Option[Expression]) extends Statement @@ -347,6 +354,9 @@ object Code { def apply(nel: NonEmptyList[Statement]): Statements = Statements(NonEmptyChain.fromNonEmptyList(nel)) + def apply(first: Statement, rest: Statement*): Statements = + Statements(NonEmptyChain.of(first, rest: _*)) + def combine(first: Statement, last: Statement): Statement = first match { case Statements(items) => diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 1dd33a453..35e3871fa 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -1267,10 +1267,10 @@ object PythonGen { case SearchList(locMut, init, check, optLeft) => // check to see if we can find a non-empty // list that matches check - (loop(init, slotName), boolExpr(check, slotName)).mapN { + (loop(init, slotName), boolExpr(check, slotName)).flatMapN { (initVL, checkVL) => searchList(locMut, initVL, checkVL, optLeft) - }.flatten + } } def matchString( diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 8adebad0e..2ab5c401c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -1,8 +1,7 @@ package org.bykn.bosatsu.codegen.clang import cats.data.NonEmptyList -import org.bykn.bosatsu.{PackageName, TestUtils, Identifier, Predef} -import Identifier.Name +import org.bykn.bosatsu.{PackageName, TestUtils, Identifier} class ClangGenTest extends munit.FunSuite { def assertPredefFns(fns: String*)(matches: String)(implicit loc: munit.Location) = From 1470a5c0d93081260e7f32984ba1ce72435bbc06 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Wed, 20 Nov 2024 19:23:51 -1000 Subject: [PATCH 09/11] Pass in MatchString if a string match must be total (which we already knew) (#1266) * checkpoint * improve MatchString and PythonGen code * fix use of must match in while * fix ternary test * polish * continue to improve code gen * fix tests * fix use of not as a function * fix another test OMG * move things into andCond * restore using python and --- .../bosatsu/codegen/python/CodeTest.scala | 22 ++- .../scala/org/bykn/bosatsu/Matchless.scala | 55 +++--- .../org/bykn/bosatsu/MatchlessToValue.scala | 2 +- .../main/scala/org/bykn/bosatsu/Pattern.scala | 19 ++ .../bykn/bosatsu/codegen/clang/ClangGen.scala | 2 +- .../bykn/bosatsu/codegen/python/Code.scala | 72 ++++++- .../bosatsu/codegen/python/PythonGen.scala | 186 +++++++++++------- 7 files changed, 255 insertions(+), 103 deletions(-) diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala index bfac90f09..da0007aea 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala @@ -305,7 +305,9 @@ else: } test("x.evalAnd(False) == False") { forAll(genExpr(4)) { x => - assert(x.evalAnd(Code.Const.False) == Code.Const.False) + val sx = x.simplify + val res = x.evalAnd(Code.Const.False) + assert((res == Code.Const.False) || ((res == sx) && (sx == Code.Const.Zero))) assert(Code.Const.False.evalAnd(x) == Code.Const.False) } } @@ -521,7 +523,23 @@ else: assert(tern == f.simplify) } case whoKnows => - assert(tern == Code.Ternary(t.simplify, whoKnows, f.simplify)) + if (tern == whoKnows) { + (t.simplify, f.simplify) match { + case (Code.Const.One | Code.Const.True, Code.Const.False | Code.Const.Zero) => + () + case tf => + fail(s"$tern == $whoKnows but (t,f) = $tf") + } + } + else { + (t.simplify, f.simplify) match { + case (Code.Const.False | Code.Const.Zero, Code.Const.One | Code.Const.True) => + val not = Code.Not(whoKnows) + assert(tern == not) + case (ts, fs) => + assert(tern == Code.Ternary(ts, whoKnows, fs)) + } + } } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 34a4f8a03..3fea102ba 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -120,7 +120,8 @@ object Matchless { case class MatchString( arg: CheapExpr, parts: List[StrPart], - binds: List[LocalAnonMut] + binds: List[LocalAnonMut], + mustMatch: Boolean ) extends BoolExpr // set the mutable variable to the given expr and return true case class SetMut(target: LocalAnonMut, expr: Expr) extends BoolExpr @@ -132,7 +133,7 @@ object Matchless { case TrueConst | CheckVariant(_, _, _, _) | EqualsLit(_, _) | EqualsNat(_, _) => false - case MatchString(_, _, b) => b.nonEmpty + case MatchString(_, _, b, _) => b.nonEmpty case And(b1, b2) => hasSideEffect(b1) || hasSideEffect(b2) case SearchList(_, _, b, l) => l.nonEmpty || hasSideEffect(b) @@ -458,32 +459,36 @@ object Matchless { doesMatch(arg, p, mustMatch).map(_.map { case (l0, cond, bs) => (l0, cond, (v, arg) :: bs) }) - case Pattern.StrPat(items) => - val sbinds: List[Bindable] = - items.toList - .collect { - // that each name is distinct - // should be checked in the SourceConverter/TotalityChecking code - case Pattern.StrPart.NamedStr(n) => n - case Pattern.StrPart.NamedChar(n) => n - } + case strPat @ Pattern.StrPat(items) => + strPat.simplify match { + case Some(simpler) => doesMatch(arg, simpler, mustMatch) + case None => + val sbinds: List[Bindable] = + items.toList + .collect { + // that each name is distinct + // should be checked in the SourceConverter/TotalityChecking code + case Pattern.StrPart.NamedStr(n) => n + case Pattern.StrPart.NamedChar(n) => n + } - val pat = items.toList.map { - case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr - case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar - case Pattern.StrPart.WildStr => StrPart.WildStr - case Pattern.StrPart.WildChar => StrPart.WildChar - case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) - } + val pat = items.toList.map { + case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr + case Pattern.StrPart.NamedChar(_) => StrPart.IndexChar + case Pattern.StrPart.WildStr => StrPart.WildStr + case Pattern.StrPart.WildChar => StrPart.WildChar + case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) + } - sbinds.traverse { b => - makeAnon.map(LocalAnonMut(_)).map((b, _)) - } - .map { binds => - val ms = binds.map(_._2) + sbinds.traverse { b => + makeAnon.map(LocalAnonMut(_)).map((b, _)) + } + .map { binds => + val ms = binds.map(_._2) - NonEmptyList.one((ms, MatchString(arg, pat, ms), binds)) - } + NonEmptyList.one((ms, MatchString(arg, pat, ms, mustMatch), binds)) + } + } case lp @ Pattern.ListPat(_) => lp.toPositionalStruct(empty, cons) match { case Right(p) => doesMatch(arg, p, mustMatch) diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index 19da5db6b..5c5b40078 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -201,7 +201,7 @@ object MatchlessToValue { scope.updateMut(mut, exprF(scope)) true } - case MatchString(str, pat, binds) => + case MatchString(str, pat, binds, _) => // do this before we evaluate the string binds match { case Nil => diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index eca2f87a2..203dc68e7 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -465,6 +465,25 @@ object Pattern { } } + /** + * Convert this to simpler pattern, if possible (such + * as Literal, Wild, Var) + */ + def simplify: Option[Pattern[Nothing, Nothing]] = + parts match { + case NonEmptyList(StrPart.WildStr, Nil) => Some(Pattern.WildCard) + case NonEmptyList(StrPart.NamedStr(n), Nil) => Some(Pattern.Var(n)) + case _ => + val allStrings = parts.traverse { + case StrPart.LitStr(s) => Some(s) + case _ => None + } + + allStrings.map { strs => + Pattern.Literal(Lit.Str(strs.combineAll)) + } + } + lazy val toNamedSeqPattern: NamedSeqPattern[Int] = StrPat.toNamedSeqPattern(this) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 6b4064b1b..c3f8949dd 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -324,7 +324,7 @@ object ClangGen { .flatMapN { (condV, initV) => searchList(lst, initV, condV, leftAcc) } - case ms @ MatchString(arg, parts, binds) => + case ms @ MatchString(arg, parts, binds, mustMatch) => // TODO: ??? println(s"TODO: implement boolToValue($ms) returning false") pv(Code.FalseLit) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala index dbcad533a..39b040c27 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala @@ -17,7 +17,16 @@ object Code { // Not necessarily code, but something that has a final value // this allows us to add IfElse as a an expression (which // is not yet valid python) a series of lets before an expression - sealed trait ValueLike + sealed trait ValueLike { + def returnsBool: Boolean = + this match { + case PyBool(_) => true + case _: Expression => false + case WithValue(_, v) => v.returnsBool + case IfElse(ifs, elseCond) => + elseCond.returnsBool && ifs.forall { case (_, v) => v.returnsBool } + } + } sealed abstract class Expression extends ValueLike with Code { @@ -61,7 +70,7 @@ object Code { evalPlus(that) def unary_! : Expression = - Code.Ident("not")(this) + Not(this) def evalMinus(that: Expression): Expression = eval(Const.Minus, that) @@ -148,6 +157,13 @@ object Code { case PyBool(b) => if (b) trueDoc else falseDoc + case Not(n) => + val nd = n match { + case Ident(_) | Parens(_) | PyBool(_) | PyInt(_) | Apply(_, _) | DotSelect(_, _) | SelectItem(_, _) | SelectRange(_, _, _) => + exprToDoc(n) + case p => par(exprToDoc(p)) + } + Doc.text("not ") + nd case Ident(i) => Doc.text(i) case o @ Op(_, _, _) => o.toDoc case Parens(inner @ Parens(_)) => exprToDoc(inner) @@ -272,6 +288,18 @@ object Code { def simplify: Expression = this def countOf(i: Ident) = if (i == this) 1 else 0 } + case class Not(arg: Expression) extends Expression { + def simplify: Expression = + arg.simplify match { + case Not(a) => a + case PyBool(b) => PyBool(true ^ b) + case Const.Zero => Const.True + case Const.One => Const.False + case other => Not(other) + } + + def countOf(i: Ident) = arg.countOf(i) + } // Binary operator used for +, -, and, == etc... case class Op(left: Expression, op: Operator, right: Expression) extends Expression { @@ -429,7 +457,7 @@ object Code { case Op(a, Const.And, b) => a.simplify match { case Const.True => b.simplify - case Const.False => Const.False + case as @ (Const.False | Const.Zero) => as case a1 => b.simplify match { case Const.True => a1 @@ -515,7 +543,17 @@ object Code { case PyInt(i) => if (i != BigInteger.ZERO) ifTrue.simplify else ifFalse.simplify case notStatic => - Ternary(ifTrue.simplify, notStatic, ifFalse.simplify) + + (ifTrue.simplify, ifFalse.simplify) match { + case (Const.One | Const.True, Const.Zero | Const.False) => + // this is just the condition + notStatic + case (Const.Zero | Const.False, Const.One | Const.True) => + // this is just the not(condition) + Not(notStatic) + case (st, sf) => + Ternary(st, notStatic, sf) + } } } case class MakeTuple(args: List[Expression]) extends Expression { @@ -614,6 +652,26 @@ object Code { elseCond: ValueLike ) extends ValueLike + object ValueLike { + def ifThenElse(c: Expression, t: ValueLike, e: ValueLike): ValueLike = + c match { + case PyBool(b) => if (b) t else e + case Const.Zero => e + case Const.One => t + case _=> + // we can't evaluate now + e match { + case IfElse(econds, eelse) => IfElse((c, t) :: econds, eelse) + case ex: Expression => + t match { + case tx: Expression => Ternary(tx, c, ex).simplify + case _ => IfElse(NonEmptyList.one((c, t)), ex) + } + case notIf => IfElse(NonEmptyList.one((c, t)), notIf) + } + } + } + ///////////////////////// // Here are all the Statements ///////////////////////// @@ -750,6 +808,7 @@ object Code { def substitute(subMap: Map[Ident, Expression], in: Expression): Expression = in match { case PyInt(_) | PyString(_) | PyBool(_) => in + case Not(n) => Not(substitute(subMap, n)) case i @ Ident(_) => subMap.get(i) match { case Some(value) => value @@ -803,8 +862,9 @@ object Code { def loop(ex: Expression, bound: Set[Ident]): Set[Ident] = ex match { case PyInt(_) | PyString(_) | PyBool(_) => Set.empty - case i @ Ident(_) => - if (bound(i)) Set.empty + case Not(e) => loop(e, bound) + case i @ Ident(n) => + if (pyKeywordList(n) || bound(i)) Set.empty else Set(i) case Op(left, _, right) => loop(left, bound) | loop(right, bound) case Parens(expr) => diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 35e3871fa..cba0cb7d8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -269,43 +269,51 @@ object PythonGen { // $COVERAGE-ON$ } + def ifElse1( + cond: ValueLike, + tCase: ValueLike, + fCase: ValueLike): Env[ValueLike] = + + cond match { + case cx: Expression => + Env.pure(Code.ValueLike.ifThenElse(cx, tCase, fCase)) + case WithValue(cs, cv) if cv.returnsBool => + // Nest into the condition + ifElse1(cv, tCase, fCase).map(cs.withValue(_)) + case ife@ IfElse(ccond, celse) if ife.returnsBool => + // this is basically the distributive property over if/else + // if (if (c) then x else y) then z else w == + // if (c) then (if x then z else w) else (if y then z else w) + (ccond.traverse { case (c, ct) => + ifElse1(ct, tCase, fCase) + .map(r => (c, r)) + }, ifElse1(celse, tCase, fCase)) + .flatMapN { (conds, elseCase) => + ifElse(conds, elseCase) + } + case _ => + for { + // allocate a new unshadowable var + cv <- Env.newAssignableVar + res <- ifElse1(cv, tCase, fCase) + } yield (cv := cond).withValue(res) + } + def ifElse( conds: NonEmptyList[(ValueLike, ValueLike)], elseV: ValueLike - ): Env[ValueLike] = + ): Env[ValueLike] = { // for all the non-expression conditions, we need to defer evaluating them // until they are really needed - conds match { - case NonEmptyList((cx: Expression, t), Nil) => - (t, elseV) match { - case (tx: Expression, elseX: Expression) => - Env.pure(Ternary(tx, cx, elseX).simplify) - case _ => - Env.pure(IfElse(NonEmptyList.one((cx, t)), elseV)) + val (c, t) = conds.head + NonEmptyList.fromList(conds.tail) match { + case Some(rest) => + ifElse(rest, elseV).flatMap { fcase => + ifElse1(c, t, fcase) } - case NonEmptyList((cx: Expression, t), rh :: rt) => - val head = (cx, t) - ifElse(NonEmptyList(rh, rt), elseV).map { - case IfElse(crest, er) => - // preserve IfElse chains - IfElse(head :: crest, er) - case nestX: Expression => - t match { - case tx: Expression => - Ternary(tx, cx, nestX).simplify - case _ => - IfElse(NonEmptyList.one(head), nestX) - } - case nest => - IfElse(NonEmptyList.one(head), nest) - } - case NonEmptyList((cx, t), rest) => - for { - // allocate a new unshadowable var - cv <- Env.newAssignableVar - res <- ifElse(NonEmptyList((cv, t), rest), elseV) - } yield (cv := cx).withValue(res) + case None => ifElse1(c, t, elseV) } + } def ifElseS( cond: ValueLike, @@ -316,6 +324,15 @@ object PythonGen { case x: Expression => Env.pure(Code.ifElseS(x, thenS, elseS)) case WithValue(stmt, vl) => ifElseS(vl, thenS, elseS).map(stmt +: _) + case ife @ IfElse(ifs, elseCond) if ife.returnsBool => + // every branch has a statically known boolean result + (ifs.traverse { case (cond, t) => + ifElseS(t, thenS, elseS) + .map(s => (cond, s)) + }, ifElseS(elseCond, thenS, elseS)) + .mapN { (ifs, elseCond) => + Code.ifStatement(ifs, Some(elseCond)) + } case v => // this is a branch, don't multiply code by writing on each // branch, that could give an exponential blowup @@ -330,28 +347,21 @@ object PythonGen { def andCode(c1: ValueLike, c2: ValueLike): Env[ValueLike] = (c1, c2) match { - case (t: Expression, c2) if t.simplify == Code.Const.True => - Env.pure(c2) - case (_, x2: Expression) => - onLast(c1)(_.evalAnd(x2)) - case _ => - // we know that c2 is not a simple expression - // res = False - // if c1: - // res = c2 - Env.onLastM(c1) { x1 => - for { - res <- Env.newAssignableVar - ifstmt <- ifElseS(x1, res := c2, Code.Pass) - } yield { - Code - .block( - res := Code.Const.False, - ifstmt - ) - .withValue(res) - } + case (e1: Expression, _) => + c2 match { + case e2: Expression => Env.pure(e1.evalAnd(e2)) + case _ => + // and(x, y) == if x: y else: False + Env.pure(Code.ValueLike.ifThenElse(e1, c2, Code.Const.False)) } + case (IfElse(cs, e), x2: Expression) => + (cs.traverse { case (c, t) => andCode(t, x2).map(t => (c, t)) }, + andCode(e, x2) + ).mapN(IfElse(_, _)) + case (WithValue(s, c1), c2) => + andCode(c1, c2).map(s.withValue(_)) + case _ => + Env.onLastM(c1) { andCode(_, c2) } } def makeDef( @@ -395,14 +405,14 @@ object PythonGen { case Ternary(ifTrue, cond, ifFalse) => // both results are in the tail position (loop(ifTrue), loop(ifFalse)).mapN { (t, f) => - ifElse(NonEmptyList.one((cond, t)), f) + ifElse1(cond, t, f) }.flatten case WithValue(stmt, v) => loop(v).map(stmt.withValue(_)) // the rest cannot have a call in the tail position case DotSelect(_, _) | Op(_, _, _) | Lambda(_, _) | MakeTuple(_) | MakeList(_) | SelectItem(_, _) | SelectRange(_, _, _) | Ident(_) | - PyBool(_) | PyString(_) | PyInt(_) => + PyBool(_) | PyString(_) | PyInt(_) | Not(_) => Env.pure(body) } @@ -1257,12 +1267,12 @@ object PythonGen { (ident := resx).withValue(Code.Const.True) } }.flatten - case MatchString(str, pat, binds) => + case MatchString(str, pat, binds, mustMatch) => ( loop(str, slotName), binds.traverse { case LocalAnonMut(m) => Env.nameForAnon(m) } ).mapN { (strVL, binds) => - Env.onLastM(strVL)(matchString(_, pat, binds)) + Env.onLastM(strVL)(matchString(_, pat, binds, mustMatch)) }.flatten case SearchList(locMut, init, check, optLeft) => // check to see if we can find a non-empty @@ -1276,7 +1286,8 @@ object PythonGen { def matchString( strEx: Expression, pat: List[StrPart], - binds: List[Code.Ident] + binds: List[Code.Ident], + mustMatch: Boolean ): Env[ValueLike] = { import StrPart.{LitStr, Glob, CharPart} val bindArray = binds.toArray @@ -1285,18 +1296,24 @@ object PythonGen { def loop( offsetIdent: Code.Ident, pat: List[StrPart], - next: Int + next: Int, + mustMatch: Boolean ): Env[ValueLike] = pat match { + case _ if mustMatch && next == bindArray.length => + // we have to match and we've captured everything + Env.pure(Code.Const.True) case Nil => // offset == str.length - Env.pure(offsetIdent =:= strEx.len()) + if (mustMatch) Env.pure(Code.Const.True) + else Env.pure(offsetIdent =:= strEx.len()) case LitStr(expect) :: tail => // val len = expect.length // str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) // // strEx.startswith(expect, offsetIdent) - loop(offsetIdent, tail, next) + // note: a literal string can never be a total match, so mustMatch is false + loop(offsetIdent, tail, next, mustMatch = false) .flatMap { loopRes => val regionMatches = strEx.dot(Code.Ident("startswith"))(expect, offsetIdent) @@ -1311,7 +1328,9 @@ object PythonGen { Env.andCode(regionMatches, rest) } case (c: CharPart) :: tail => - val matches = offsetIdent :< strEx.len() + val matches = + if (mustMatch) Code.Const.True + else offsetIdent :< strEx.len() val n1 = if (c.capture) (next + 1) else next val stmt = if (c.capture) { @@ -1324,7 +1343,7 @@ object PythonGen { .withValue(true) } else (offsetIdent := offsetIdent + 1).withValue(true) for { - tailRes <- loop(offsetIdent, tail, n1) + tailRes <- loop(offsetIdent, tail, n1, mustMatch) and2 <- Env.andCode(stmt, tailRes) and1 <- Env.andCode(matches, and2) } yield and1 @@ -1385,7 +1404,8 @@ object PythonGen { Env.newAssignableVar, Env.newAssignableVar ).mapN { (start, result, candidate, candOffset) => - val searchEnv = loop(candOffset, tail2, next1) + // note, a literal prefix can never be a total match + val searchEnv = loop(candOffset, tail2, next1, mustMatch = false) def onSearch(search: ValueLike): Env[Statement] = Env.ifElseS( @@ -1451,7 +1471,9 @@ object PythonGen { for { matched <- Env.newAssignableVar off1 <- Env.newAssignableVar - tailMatched <- loop(off1, tail, next1) + // the tail match isn't true, because we loop until we find + // a case + tailMatched <- loop(off1, tail, next1, false) matchStmt = Code .block( @@ -1462,7 +1484,7 @@ object PythonGen { matched := tailMatched // the tail match increments the ) ) - .withValue(matched) + .withValue(if (mustMatch) Code.Const.True else matched) fullMatch <- if (!h.capture) Env.pure(matchStmt) @@ -1486,10 +1508,38 @@ object PythonGen { } } - for { - offsetIdent <- Env.newAssignableVar - res <- loop(offsetIdent, pat, 0) - } yield (offsetIdent := 0).withValue(res) + pat match { + // handle some common special cases + case (c: StrPart.CharPart) :: Nil => + // single character + val matches = if (mustMatch) Code.Const.True else strEx.len() =:= 1 + if (c.capture) { + val stmt = bindArray(0) := Code.SelectItem(strEx, 0) + Env.ifElse( + NonEmptyList.one((matches, stmt.withValue(true))), + Code.Const.False) + } + else { + Env.pure(matches) + } + case StrPart.WildStr :: (c: StrPart.CharPart) :: Nil => + // last character + val matches = if (mustMatch) Code.Const.True else strEx.len() :> 0 + if (c.capture) { + val stmt = bindArray(0) := Code.SelectItem(strEx, -1) + Env.ifElse( + NonEmptyList.one((matches, stmt.withValue(true))), + Code.Const.False) + } + else { + Env.pure(matches) + } + case _ => + for { + offsetIdent <- Env.newAssignableVar + res <- loop(offsetIdent, pat, 0, mustMatch) + } yield (offsetIdent := 0).withValue(res) + } } def searchList( From be4d2d6019c97ef2c3a6c577a451571dfa065fc3 Mon Sep 17 00:00:00 2001 From: Scala Steward <43047562+scala-steward@users.noreply.github.com> Date: Thu, 21 Nov 2024 18:29:13 +0100 Subject: [PATCH 10/11] Update cats-effect to 3.5.6 (#1268) --- project/Dependencies.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/Dependencies.scala b/project/Dependencies.scala index d5a865d56..d8e40be7f 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -4,7 +4,7 @@ import org.portablescala.sbtplatformdeps.PlatformDepsPlugin.autoImport._ object Dependencies { lazy val cats = Def.setting("org.typelevel" %%% "cats-core" % "2.12.0") lazy val catsEffect = - Def.setting("org.typelevel" %%% "cats-effect" % "3.5.5") + Def.setting("org.typelevel" %%% "cats-effect" % "3.5.6") lazy val catsParse = Def.setting("org.typelevel" %%% "cats-parse" % "1.0.0") lazy val decline = Def.setting("com.monovore" %%% "decline" % "2.4.1") lazy val ff4s = Def.setting("io.github.buntec" %%% "ff4s" % "0.24.0") From 6fd56074f37a40f742c64febb974d1c3fd2f57db Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Thu, 21 Nov 2024 13:57:09 -1000 Subject: [PATCH 11/11] More improvements to MatchString codegen on Python (#1269) * More improvements to MatchString codegen on Python * remove done TODO * remove loop in MatchString with globs --- .../bykn/bosatsu/codegen/python/Code.scala | 11 +- .../bosatsu/codegen/python/PythonGen.scala | 235 +++++++++++++----- test_workspace/Char.bosatsu | 20 ++ 3 files changed, 198 insertions(+), 68 deletions(-) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala index 39b040c27..4d7f60c0a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala @@ -60,6 +60,9 @@ object Code { def =:=(that: Expression): Expression = Code.Op(this, Code.Const.Eq, that) + def =!=(that: Expression): Expression = + Code.Op(this, Code.Const.Neq, that) + def evalAnd(that: Expression): Expression = eval(Const.And, that) @@ -295,6 +298,10 @@ object Code { case PyBool(b) => PyBool(true ^ b) case Const.Zero => Const.True case Const.One => Const.False + case Op(left, Const.Eq, right) => + Op(left, Const.Neq, right) + case Op(left, Const.Neq, right) => + Op(left, Const.Eq, right) case other => Not(other) } @@ -550,7 +557,7 @@ object Code { notStatic case (Const.Zero | Const.False, Const.One | Const.True) => // this is just the not(condition) - Not(notStatic) + Not(notStatic).simplify case (st, sf) => Ternary(st, notStatic, sf) } @@ -715,6 +722,7 @@ object Code { lst match { case Nil => (Nil, Pass) case (Code.Const.True, last) :: _ => (Nil, last) + case (Code.Const.False, _) :: tail => untilTrue(tail) case head :: tail => val (rest, e) = untilTrue(tail) (head :: rest, e) @@ -1004,6 +1012,7 @@ object Code { case object Neq extends Operator("!=") case object Gt extends Operator(">") case object Lt extends Operator("<") + case object In extends Operator("in") val True = PyBool(true) val False = PyBool(false) diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index cba0cb7d8..01d328098 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -277,7 +277,7 @@ object PythonGen { cond match { case cx: Expression => Env.pure(Code.ValueLike.ifThenElse(cx, tCase, fCase)) - case WithValue(cs, cv) if cv.returnsBool => + case WithValue(cs, cv) => // Nest into the condition ifElse1(cv, tCase, fCase).map(cs.withValue(_)) case ife@ IfElse(ccond, celse) if ife.returnsBool => @@ -354,13 +354,16 @@ object PythonGen { // and(x, y) == if x: y else: False Env.pure(Code.ValueLike.ifThenElse(e1, c2, Code.Const.False)) } - case (IfElse(cs, e), x2: Expression) => + case (ife @ IfElse(cs, e), x2) if ife.returnsBool || x2.isInstanceOf[Expression] => + // push down into the lhs since this won't increase the final branch count (cs.traverse { case (c, t) => andCode(t, x2).map(t => (c, t)) }, andCode(e, x2) ).mapN(IfElse(_, _)) case (WithValue(s, c1), c2) => andCode(c1, c2).map(s.withValue(_)) case _ => + // we don't nest pairs of IfElse if the tails can't be evaluated now + // since that could cause an exponential explosion of code size Env.onLastM(c1) { andCode(_, c2) } } @@ -1294,6 +1297,7 @@ object PythonGen { // return a value like expression that contains the boolean result // and assigns all the bindings along the way def loop( + knownPos: Option[Int], offsetIdent: Code.Ident, pat: List[StrPart], next: Int, @@ -1305,49 +1309,70 @@ object PythonGen { Env.pure(Code.Const.True) case Nil => // offset == str.length + val off = knownPos.fold(offsetIdent: Expression) { i => (i: Expression) } if (mustMatch) Env.pure(Code.Const.True) - else Env.pure(offsetIdent =:= strEx.len()) + else Env.pure(off =:= strEx.len()) case LitStr(expect) :: tail => // val len = expect.length // str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) // // strEx.startswith(expect, offsetIdent) // note: a literal string can never be a total match, so mustMatch is false - loop(offsetIdent, tail, next, mustMatch = false) + val expectSize = expect.codePointCount(0, expect.length) + loop(knownPos.map(_ + expectSize), offsetIdent, tail, next, mustMatch = false) .flatMap { loopRes => + val off = knownPos.fold(offsetIdent: Expression) { i => (i: Expression) } val regionMatches = - strEx.dot(Code.Ident("startswith"))(expect, offsetIdent) + strEx.dot(Code.Ident("startswith"))(expect, off) val rest = - ( - offsetIdent := offsetIdent + expect.codePointCount( - 0, - expect.length - ) - ).withValue(loopRes) + if (tail.nonEmpty && (tail != (StrPart.WildStr :: Nil))) { + // we aren't done matching + (offsetIdent := offsetIdent + expectSize) + .withValue(loopRes) + } + else { + // we are done matching, no need to update offset + loopRes + } Env.andCode(regionMatches, rest) } + case (c: CharPart) :: Nil => + // last character + val off = knownPos.fold(offsetIdent + 1) { i => (i + 1): Expression } + val matches = if (mustMatch) Code.Const.True else (strEx.len() =:= off) + if (c.capture) { + val stmt = bindArray(next) := Code.SelectItem(strEx, -1) + Env.andCode(matches, stmt.withValue(true)) + } + else { + Env.pure(matches) + } case (c: CharPart) :: tail => + val off = knownPos.fold(offsetIdent: Expression) { i => (i: Expression) } val matches = if (mustMatch) Code.Const.True - else offsetIdent :< strEx.len() + else (off :< strEx.len()) val n1 = if (c.capture) (next + 1) else next val stmt = if (c.capture) { // b = str[offset] Code .block( - bindArray(next) := Code.SelectItem(strEx, offsetIdent), + bindArray(next) := Code.SelectItem(strEx, off), offsetIdent := offsetIdent + 1 ) .withValue(true) } else (offsetIdent := offsetIdent + 1).withValue(true) for { - tailRes <- loop(offsetIdent, tail, n1, mustMatch) + tailRes <- loop(knownPos.map(_ + 1), offsetIdent, tail, n1, mustMatch) and2 <- Env.andCode(stmt, tailRes) and1 <- Env.andCode(matches, and2) } yield and1 case (h: Glob) :: tail => + // after a glob, we no longer know the knownPos + val off = knownPos.fold(offsetIdent: Expression) { i => (i: Expression) } + val knownPos1 = None tail match { case Nil => // we capture all the rest @@ -1355,16 +1380,90 @@ object PythonGen { if (h.capture) { // b = str[offset:] (bindArray(next) := Code - .SelectRange(strEx, Some(offsetIdent), None)) + .SelectRange(strEx, Some(off), None)) .withValue(true) } else Code.Const.True ) - case LitStr(expect) :: tail2 => + case LitStr(expect) :: Nil => + // if strEx.endswith(expect): + // h = strEx[off:-(expect.len)] + val elen = expect.codePointCount(0, expect.length) + val matches = + if (mustMatch) Code.Const.True + else strEx.dot(Code.Ident("endswith"))(Code.PyString(expect)) + + if (h.capture) { + Env.andCode(matches, + (binds(next) := Code.SelectRange(strEx, Some(off), Some(-elen: Expression))) + .withValue(true) + ) + } + else + Env.pure(matches) + case LitStr(expect) :: (g2: Glob) :: Nil => + // we could implement this kind of match: .*expect.* + // which is partition: + // (left, e, right) = strEx[offset:].partition(expect) + // if e: + // h = left + // e = right + // True + // else: + // False + val base = knownPos match { + case Some(0) => strEx + case _ => Code.SelectRange(strEx, Some(off), None) + } + if (h.capture || g2.capture) { + var npos = next + for { + pres <- Env.newAssignableVar + // TODO, maybe partition isn't actually ideal here + cmd = pres := base.dot(Code.Ident("partition"))(expect) + hbind = + if (h.capture) { + val b = npos + npos = npos + 1 + binds(b) := pres.get(0) + } else Code.Pass + gbind = + if (g2.capture) { + val b = npos + npos = npos + 1 + binds(b) := pres.get(2) + } else Code.Pass + cond = pres.get(1) =!= Code.PyString("") + m <- Env.andCode( + cmd.withValue(cond), + (hbind +: gbind).withValue(true)) + } yield m + } + else { + // this is just expect in strEx[off:] + Env.pure(knownPos match { + case Some(0) => + Code.Op(Code.PyString(expect), Code.Const.In, strEx) + case _ => + strEx.dot(Code.Ident("find"))(expect, off) :> (-1: Expression) + }) + } + case LitStr(expect) :: (tail2 @ (th :: _)) => + // well formed patterns can really only + // have Nil, Glob :: _, Char :: _ after LitStr + // since we should have combined adjacent LitStr + // // here we have to make a loop // searching for expect, and then see if we // can match the rest of the pattern val next1 = if (h.capture) next + 1 else next + // there is a glob before and after expect, + // so, if the tail, .*x can't match s + // then it can't match any suffix of s. + val shouldSearch = th match { + case _: Glob => false + case _ => true + } /* * this is the scala code for the below * it is in MatchlessToValue but left here @@ -1403,9 +1502,9 @@ object PythonGen { Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar - ).mapN { (start, result, candidate, candOffset) => + ).flatMapN { (start, result, candidate, candOffset) => // note, a literal prefix can never be a total match - val searchEnv = loop(candOffset, tail2, next1, mustMatch = false) + val searchEnv = loop(knownPos1, candOffset, tail2, next1, mustMatch = false) def onSearch(search: ValueLike): Env[Statement] = Env.ifElseS( @@ -1415,7 +1514,7 @@ object PythonGen { if (h.capture) (bindArray(next) := Code.SelectRange( strEx, - Some(offsetIdent), + Some(off), Some(candidate) )) else Code.Pass @@ -1451,21 +1550,51 @@ object PythonGen { for { search <- searchEnv find <- findBranch(search) - } yield (Code - .block( - start := offsetIdent, - result := false, - Code.While( - (start :> -1), - Code.block( - candidate := strEx - .dot(Code.Ident("find"))(expect, start), - find + } yield if (shouldSearch) { + Code + .block( + start := off, + result := false, + Code.While( + (start :> -1), + Code.block( + candidate := strEx + .dot(Code.Ident("find"))(expect, start), + find + ) ) ) - ) - .withValue(result)) - }.flatten + .withValue(result) + } + else { + Code + .block( + start := off, + result := false, + candidate := strEx.dot(Code.Ident("find"))(expect, start), + find + ) + .withValue(result) + } + } + case (c: CharPart) :: Nil => + // last character + val matches = if (mustMatch) Code.Const.True else (strEx.len() :> off) + val cpart = if (c.capture) { + val cnext = if (h.capture) (next + 1) else next + val stmt = bindArray(cnext) := Code.SelectItem(strEx, -1) + Env.andCode(matches, stmt.withValue(true)) + } + else { + Env.pure(matches) + } + val hpart = + if (h.capture) { + bindArray(next) := Code.SelectRange(strEx, Some(off), Some(-1: Expression)) + } + else Code.Pass + + cpart.map(hpart.withValue(_)) case (_: CharPart) :: _ => val next1 = if (h.capture) (next + 1) else next for { @@ -1473,12 +1602,12 @@ object PythonGen { off1 <- Env.newAssignableVar // the tail match isn't true, because we loop until we find // a case - tailMatched <- loop(off1, tail, next1, false) + tailMatched <- loop(knownPos1, off1, tail, next1, false) matchStmt = Code .block( matched := false, - off1 := offsetIdent, + off1 := off, Code.While( (!matched).evalAnd(off1 :< strEx.len()), matched := tailMatched // the tail match increments the @@ -1492,7 +1621,7 @@ object PythonGen { val capture = Code .block( bindArray(next) := Code - .SelectRange(strEx, Some(offsetIdent), Some(off1)) + .SelectRange(strEx, Some(off), Some(off1)) ) .withValue(true) Env.andCode(matchStmt, capture) @@ -1508,38 +1637,10 @@ object PythonGen { } } - pat match { - // handle some common special cases - case (c: StrPart.CharPart) :: Nil => - // single character - val matches = if (mustMatch) Code.Const.True else strEx.len() =:= 1 - if (c.capture) { - val stmt = bindArray(0) := Code.SelectItem(strEx, 0) - Env.ifElse( - NonEmptyList.one((matches, stmt.withValue(true))), - Code.Const.False) - } - else { - Env.pure(matches) - } - case StrPart.WildStr :: (c: StrPart.CharPart) :: Nil => - // last character - val matches = if (mustMatch) Code.Const.True else strEx.len() :> 0 - if (c.capture) { - val stmt = bindArray(0) := Code.SelectItem(strEx, -1) - Env.ifElse( - NonEmptyList.one((matches, stmt.withValue(true))), - Code.Const.False) - } - else { - Env.pure(matches) - } - case _ => - for { - offsetIdent <- Env.newAssignableVar - res <- loop(offsetIdent, pat, 0, mustMatch) - } yield (offsetIdent := 0).withValue(res) - } + for { + offsetIdent <- Env.newAssignableVar + res <- loop(Some(0), offsetIdent, pat, 0, mustMatch) + } yield (offsetIdent := 0).withValue(res) } def searchList( diff --git a/test_workspace/Char.bosatsu b/test_workspace/Char.bosatsu index a690e6d4c..6a05d05ce 100644 --- a/test_workspace/Char.bosatsu +++ b/test_workspace/Char.bosatsu @@ -58,11 +58,31 @@ match_tests = TestSuite("match tests", Assertion("abc👋👋👋" matches "${_}👋", "test matching 6"), ]) +def starts_with_foo(s): s matches "foo${_}" +def ends_with_foo(s): s matches "${_}foo" +def contains_foo(s): s matches "${_}foo${_}" +def contains_foo_bar(s): s matches "${_}foo${_}bar${_}" + +glob_match_tests = TestSuite("glob_match_suites", + [ + Assertion(starts_with_foo("foobar"), "starts_with_foo(foobar)"), + Assertion(starts_with_foo("barfoo") matches False, "starts_with_foo(foobar)"), + Assertion(ends_with_foo("foobar") matches False, "ends_with_foo(foobar)"), + Assertion(ends_with_foo("barfoo"), "ends_with_foo(foobar)"), + Assertion(contains_foo("barfoo"), "contains_foo(foobar)"), + Assertion(contains_foo("barbar") matches False, "contains_foo(barbar)"), + Assertion(contains_foo_bar("there is foo and bar"), "there is foo and bar"), + Assertion(contains_foo_bar("there is foobar"), "there is foobar"), + Assertion(contains_foo_bar("there is foo but not the other") matches False, + "there is foo but not the other"), + ]) + tests = TestSuite("Char tests", [ str_to_char_tests, len_test, last_tests, match_tests, + glob_match_tests, ] ) \ No newline at end of file