diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala new file mode 100644 index 000000000..11583c22e --- /dev/null +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -0,0 +1,67 @@ +package org.bykn.bosatsu.codegen.clang + +import cats.data.NonEmptyList +import org.bykn.bosatsu.{PackageName, PackageMap, TestUtils, Identifier, Predef} +import Identifier.Name +import org.bykn.bosatsu.MatchlessFromTypedExpr + +import org.bykn.bosatsu.DirectEC.directEC + +class ClangGenTest extends munit.FunSuite { + val predef_c = Code.Include(true, "bosatsu_predef.h") + + def predef(s: String, arity: Int) = + (PackageName.PredefName -> Name(s)) -> (predef_c, + ClangGen.generatedName(PackageName.PredefName, Name(s)), + arity) + + val jvmExternals = { + val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) } + .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] + + { (pn: (PackageName, Identifier)) => ext.get(pn) } + } + + def md5HashToHex(content: String): String = { + val md = java.security.MessageDigest.getInstance("MD5") + val digest = md.digest(content.getBytes("UTF-8")) + digest.map("%02x".format(_)).mkString + } + def testFilesCompilesToHash(path0: String, paths: String*)(hashHex: String)(implicit loc: munit.Location) = { + val pm: PackageMap.Typed[Any] = TestUtils.compileFile(path0, paths*) + /* + val exCode = ClangGen.generateExternalsStub(pm) + println(exCode.render(80)) + sys.error("stop") + */ + val matchlessMap = MatchlessFromTypedExpr.compile(pm) + val topoSort = pm.topoSort.toSuccess.get + val sortedEnv = cats.Functor[Vector].compose[NonEmptyList].map(topoSort) { pn => + (pn, matchlessMap(pn)) + } + + val res = ClangGen.renderMain( + sortedEnv = sortedEnv, + externals = jvmExternals, + value = (PackageName.PredefName, Identifier.Name("ignored")), + evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) + ) + + res match { + case Right(d) => + val everything = d.render(80) + val hashed = md5HashToHex(everything) + assertEquals(hashed, hashHex, s"compilation didn't match. Compiled code:\n\n${"//" * 40}\n\n$everything") + case Left(e) => fail(e.toString) + } + } + + test("test_workspace/Ackermann.bosatsu") { + /* + To inspect the code, change the hash, and it will print the code out + */ + testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")( + "46716ef3c97cf2a79bf17d4033d55854" + ) + } +} \ No newline at end of file diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala index 7c6f197d5..391d98b67 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala @@ -1,17 +1,12 @@ package org.bykn.bosatsu.codegen.python import cats.Show -import cats.data.NonEmptyList import java.io.{ByteArrayInputStream, InputStream} -import java.nio.file.{Paths, Files} import java.util.concurrent.Semaphore import org.bykn.bosatsu.{ - PackageMap, MatchlessFromTypedExpr, - Parser, - Package, - LocationMap, - PackageName + PackageName, + TestUtils } import org.scalacheck.Gen import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ @@ -24,6 +19,8 @@ import org.python.core.{PyInteger, PyFunction, PyObject, PyTuple} import org.bykn.bosatsu.DirectEC.directEC import org.scalatest.funsuite.AnyFunSuite +import TestUtils.compileFile + // Jython seems to have some thread safety issues object JythonBarrier { private val sem = new Semaphore(1) @@ -87,27 +84,6 @@ class PythonGenTest extends AnyFunSuite { } } - def compileFile(path: String, rest: String*): PackageMap.Typed[Any] = { - def toS(s: String): String = - new String(Files.readAllBytes(Paths.get(s)), "UTF-8") - - val packNEL = - NonEmptyList(path, rest.toList) - .map { s => - val str = toS(s) - val pack = Parser.unsafeParse(Package.parser(None), str) - (("", LocationMap(str)), pack) - } - - val res = PackageMap.typeCheckParsed(packNEL, Nil, "") - res.left match { - case Some(err) => sys.error(err.toString) - case None => () - } - - res.right.get - } - def isfromString(s: String): InputStream = new ByteArrayInputStream(s.getBytes("UTF-8")) diff --git a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala index 772642b5c..789a979d8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala +++ b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala @@ -2,12 +2,12 @@ package org.bykn.bosatsu import cats.data.NonEmptyList -sealed abstract class FfiCall { +sealed abstract class FfiCall(val arity: Int) { def call(t: rankn.Type): Value } object FfiCall { - final case class Fn1(fn: Value => Value) extends FfiCall { + final case class Fn1(fn: Value => Value) extends FfiCall(1) { import Value.FnValue private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) => @@ -16,7 +16,7 @@ object FfiCall { def call(t: rankn.Type): Value = evalFn } - final case class Fn2(fn: (Value, Value) => Value) extends FfiCall { + final case class Fn2(fn: (Value, Value) => Value) extends FfiCall(2) { import Value.FnValue private[this] val evalFn: FnValue = @@ -26,7 +26,7 @@ object FfiCall { def call(t: rankn.Type): Value = evalFn } - final case class Fn3(fn: (Value, Value, Value) => Value) extends FfiCall { + final case class Fn3(fn: (Value, Value, Value) => Value) extends FfiCall(3) { import Value.FnValue private[this] val evalFn: FnValue = @@ -37,10 +37,6 @@ object FfiCall { def call(t: rankn.Type): Value = evalFn } - final case class FromFn(callFn: rankn.Type => Value) extends FfiCall { - def call(t: rankn.Type): Value = callFn(t) - } - def getJavaType(t: rankn.Type): List[Class[_]] = { def one(t: rankn.Type): Option[Class[_]] = loop(t, false) match { diff --git a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala index 8f60ad42f..2fadd7efa 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MainModule.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MainModule.scala @@ -630,7 +630,7 @@ abstract class MainModule[IO[_]](implicit val intrinsic = PythonGen.intrinsicValues val missingExternals = allExternals.iterator.flatMap { case (p, names) => - val missing = names.filterNot { case n => + val missing = names.filterNot { case (n, _) => exts((p, n)) || intrinsic.get(p).exists(_(n)) } @@ -703,7 +703,7 @@ abstract class MainModule[IO[_]](implicit Doc.char('[') + Doc.intercalate( Doc.comma + Doc.lineOrSpace, - names.map(b => Doc.text(b.sourceCodeRepr)) + names.map { case (b, _) => Doc.text(b.sourceCodeRepr) } ) + Doc.char(']')).nested(4) } diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala index aca53813e..29ad31053 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala @@ -45,13 +45,17 @@ case class PackageMap[A, B, C, +D]( def allExternals(implicit ev: Package[A, B, C, D] <:< Package.Typed[Any] - ): Map[PackageName, List[Identifier.Bindable]] = + ): Map[PackageName, List[(Identifier.Bindable, rankn.Type)]] = toMap.iterator.map { case (name, pack) => - (name, ev(pack).externalDefs) + val tpack = ev(pack) + (name, tpack.externalDefs.map { n => + (n, tpack.types.getExternalValue(name, n) + .getOrElse(sys.error(s"invariant violation, unknown type: $name $n")) ) + }) }.toMap def topoSort( - ev: Package[A, B, C, D] <:< Package.Typed[Any] + implicit ev: Package[A, B, C, D] <:< Package.Typed[Any] ): Toposort.Result[PackageName] = { val packNames = toMap.keys.iterator.toList.sorted diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 933d323c1..2867f9507 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -5,8 +5,8 @@ import cats.data.{StateT, EitherT, NonEmptyList, Chain} import java.math.BigInteger import java.nio.charset.StandardCharsets import org.bykn.bosatsu.codegen.Idents -import org.bykn.bosatsu.rankn.DataRepr -import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName} +import org.bykn.bosatsu.rankn.{DataRepr, Type} +import org.bykn.bosatsu.{Identifier, Lit, Matchless, PackageName, PackageMap} import org.bykn.bosatsu.Matchless.Expr import org.bykn.bosatsu.Identifier.Bindable import org.typelevel.paiges.Doc @@ -21,9 +21,43 @@ object ClangGen { case class Unbound(bn: Bindable, inside: Option[(PackageName, Bindable)]) extends Error } + def generateExternalsStub(pm: PackageMap.Typed[Any]): Doc = { + val includes = Code.Include(true, "bosatsu_runtime.h") :: Nil + + def toStmt(pn: PackageName, ident: Identifier.Bindable, arity: Int): Code.Statement = { + val cIdent = generatedName(pn, ident) + val args = Idents.allSimpleIdents.take(arity).map { nm => + Code.Param(Code.TypeIdent.BValue, Code.Ident(nm)) + } + Code.DeclareFn(Nil, Code.TypeIdent.BValue, cIdent, args.toList, Some( + Code.block(Code.Return(Some(Code.IntLiteral.Zero))) + )) + } + + def tpeArity(t: Type): Int = + t match { + case Type.Fun.MaybeQuant(_, args, _) => args.length + case _ => 0 + } + + val fns = pm.allExternals + .iterator + .flatMap { case (p, vs) => + vs.iterator.map { case (n, tpe) => + Code.toDoc(toStmt(p, n, tpeArity(tpe))) + } + } + .toList + + val line2 = Doc.hardLine + Doc.hardLine + + Doc.intercalate(Doc.hardLine, includes.map(Code.toDoc)) + line2 + + Doc.intercalate(line2, fns) + } + def renderMain( sortedEnv: Vector[NonEmptyList[(PackageName, List[(Bindable, Expr)])]], - externals: Map[(PackageName, Bindable), (Code.Include, Code.Ident)], + externals: ((PackageName, Bindable)) => Option[(Code.Include, Code.Ident, Int)], value: (PackageName, Bindable), evaluator: (Code.Include, Code.Ident) ): Either[Error, Doc] = { @@ -44,7 +78,7 @@ object ClangGen { .iterator.flatMap(_.iterator) .flatMap { case (p, vs) => vs.iterator.map { case (b, e) => - (p, b) -> (e, Impl.generatedName(p, b)) + (p, b) -> (e, generatedName(p, b)) } } .toMap @@ -52,15 +86,15 @@ object ClangGen { run(allValues, externals, res) } - private object Impl { - type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)] - type Externals = Map[(PackageName, Bindable), (Code.Include, Code.Ident)] + private def fullName(p: PackageName, b: Bindable): String = + p.asString + "/" + b.asString - def fullName(p: PackageName, b: Bindable): String = - p.asString + "/" + b.asString + def generatedName(p: PackageName, b: Bindable): Code.Ident = + Code.Ident(Idents.escape("___bsts_g_", fullName(p, b))) - def generatedName(p: PackageName, b: Bindable): Code.Ident = - Code.Ident(Idents.escape("___bsts_g_", fullName(p, b))) + private object Impl { + type AllValues = Map[(PackageName, Bindable), (Expr, Code.Ident)] + type Externals = Function1[(PackageName, Bindable), Option[(Code.Include, Code.Ident, Int)]] trait Env { import Matchless._ @@ -410,11 +444,7 @@ object ClangGen { case Some(nm) => pv(Code.Ident("STATIC_PUREFN")(nm)) case None => - // read_or_build(&__bvalue_foo, make_foo); - for { - value <- staticValueName(pack, name) - consFn <- constructorFn(pack, name) - } yield Code.Ident("read_or_build")(value.addr, consFn): Code.ValueLike + globalIdent(pack, name).map { nm => nm() } } case Local(arg) => directFn(arg) @@ -494,7 +524,7 @@ object ClangGen { case ZeroNat => pv(Code.Ident("BSTS_NAT_0")) case SuccNat => - val arg = Identifier.Name("arg0") + val arg = Identifier.Name("nat") // This relies on optimizing App(SuccNat, _) otherwise // it creates an infinite loop. // Also, this we should cache creation of Lambda/Closure values @@ -567,18 +597,21 @@ object ClangGen { _ <- appendStatement(stmt) } yield () case someValue => + // TODO: if we can create the value statically, we don't + // need the read_or_build trick + // // we materialize an Atomic value to hold the static data // then we generate a function to populate the value for { vl <- innerToValue(someValue) value <- staticValueName(p, b) - consFn <- constructorFn(p, b) _ <- appendStatement(Code.DeclareVar( Code.Attr.Static :: Nil, Code.TypeIdent.AtomicBValue, value, Some(Code.IntLiteral.Zero) )) + consFn <- constructorFn(p, b) _ <- appendStatement(Code.DeclareFn( Code.Attr.Static :: Nil, Code.TypeIdent.BValue, @@ -586,6 +619,15 @@ object ClangGen { Nil, Some(Code.block(Code.returnValue(vl))) )) + readFn <- globalIdent(p, b) + res = Code.Ident("read_or_build")(value.addr, consFn) + _ <- appendStatement(Code.DeclareFn( + Code.Attr.Static :: Nil, + Code.TypeIdent.BValue, + readFn, + Nil, + Some(Code.block(Code.returnValue(res))) + )) } yield () } } @@ -652,8 +694,9 @@ object ClangGen { def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] = StateT { s => val key = (pn, bn) - s.externals.get(key) match { - case Some((incl, ident)) => + s.externals(key) match { + case Some((incl, ident, _)) => + // TODO: suspect that we are ignoring arity here val withIncl = if (s.includeSet(incl)) s else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl) @@ -775,9 +818,21 @@ object ClangGen { // record that this name is a top level function, so applying it can be direct def directFn(pack: PackageName, b: Bindable): T[Option[Code.Ident]] = StateT { s => - s.allValues.get((pack, b)) match { + val key = (pack, b) + s.allValues.get(key) match { case Some((_: Matchless.FnExpr, ident)) => result(s, Some(ident)) + case None => + // this is external + s.externals(key) match { + case Some((incl, ident, arity)) if arity > 0 => + // TODO: suspect that we are ignoring arity here + val withIncl = + if (s.includeSet(incl)) s + else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl) + result(withIncl, Some(ident)) + case _ => result(s, None) + } case _ => result(s, None) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala index 565be64b4..2992894f7 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala @@ -401,8 +401,8 @@ object Code { private val doDoc = Doc.text("do ") private val whileDoc = Doc.text("while") private val arrow = Doc.text("->") - private val questionDoc = Doc.text(" ? ") - private val colonDoc = Doc.text(" : ") + private val questionDoc = Doc.text(" ?") + Doc.line + private val colonDoc = Doc.text(" :") + Doc.line private val quoteDoc = Doc.char('"') private def leftAngleDoc = BinOp.Lt.toDoc private def rightAngleDoc = BinOp.Gt.toDoc @@ -506,7 +506,7 @@ object Code { case noPar @ (Tight(_) | PrefixExpr(_, _) | BinExpr(_, _, _)) => toDoc(noPar) case yesPar => par(toDoc(yesPar)) } - d(cond) + questionDoc + d(t) + colonDoc + d(f) + (d(cond) + (questionDoc + d(t) + colonDoc + d(f)).nested(4)).grouped // Statements case Assignment(t, v) => toDoc(t) + (equalsDoc + (toDoc(v) + semiDoc)) case DeclareArray(tpe, nm, values) => diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 39c21f860..0320ee14f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -923,6 +923,15 @@ object Type { else None } + object MaybeQuant { + def unapply(t: Type): Option[(Option[Quantification], NonEmptyList[Type], Type)] = + t match { + case Quantified(quant, Fun(args, res)) => Some((Some(quant), args, res)) + case Fun(args, res) => Some((None, args, res)) + case _ => None + } + } + def unapply(t: Type): Option[(NonEmptyList[Type], Type)] = { def check( n: Int, diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index 9fb15c714..a608b093f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -1,13 +1,15 @@ package org.bykn.bosatsu -import cats.data.{Ior, Validated} -import cats.implicits._ +import cats.data.{Ior, Validated, NonEmptyList} +import java.nio.file.{Files, Paths} import org.bykn.bosatsu.rankn._ import org.scalatest.{Assertion, Assertions} import Assertions.{succeed, fail} import IorMethods.IorExtension +import cats.syntax.all._ + object TestUtils { def parsedTypeEnvOf( @@ -128,6 +130,28 @@ object TestUtils { } } + def compileFile(path: String, rest: String*)(implicit ec: Par.EC): PackageMap.Typed[Any] = { + def toS(s: String): String = + new String(Files.readAllBytes(Paths.get(s)), "UTF-8") + + val packNEL = + NonEmptyList(path, rest.toList) + .map { s => + val str = toS(s) + val pack = Parser.unsafeParse(Package.parser(None), str) + (("", LocationMap(str)), pack) + } + + val res = PackageMap.typeCheckParsed(packNEL, Nil, "") + res.left match { + case Some(err) => sys.error(err.toString) + case None => () + } + + res.right.get + } + + def makeInputArgs(files: List[(Int, Any)]): List[String] = ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { case (idx, _) => "--input" :: idx.toString :: Nil diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index b20432673..4cd00e10f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -1,16 +1,23 @@ package org.bykn.bosatsu.codegen.clang import cats.data.NonEmptyList -import org.bykn.bosatsu.codegen.Idents import org.bykn.bosatsu.{PackageName, TestUtils, Identifier, Predef} import Identifier.Name class ClangGenTest extends munit.FunSuite { val predef_c = Code.Include(true, "bosatsu_predef.h") - def predef(s: String) = + def predef(s: String, arity: Int) = (PackageName.PredefName -> Name(s)) -> (predef_c, - Code.Ident(Idents.escape("__bsts_predef_", s))) + ClangGen.generatedName(PackageName.PredefName, Name(s)), + arity) + + val jvmExternals = { + val ext = Predef.jvmExternals.toMap.iterator.map { case ((_, n), ffi) => predef(n, ffi.arity) } + .toMap[(PackageName, Identifier), (Code.Include, Code.Ident, Int)] + + { (pn: (PackageName, Identifier)) => ext.get(pn) } + } def assertPredefFns(fns: String*)(matches: String)(implicit loc: munit.Location) = TestUtils.checkMatchless(""" @@ -30,8 +37,7 @@ x = 1 sortedEnv = Vector( NonEmptyList.one(PackageName.PredefName -> matchlessMap(PackageName.PredefName)), ), - externals = - Predef.jvmExternals.toMap.keys.iterator.map { case (_, n) => predef(n) }.toMap, + externals = jvmExternals, value = (PackageName.PredefName, Identifier.Name(fns.last)), evaluator = (Code.Include(true, "eval.h"), Code.Ident("evaluator_run")) ) @@ -42,6 +48,12 @@ x = 1 } } + def md5HashToHex(content: String): String = { + val md = java.security.MessageDigest.getInstance("MD5") + val digest = md.digest(content.getBytes("UTF-8")) + digest.map("%02x".format(_)).mkString + } + test("check build_List") { assertPredefFns("build_List")("""#include "bosatsu_runtime.h"