From f6cde9224f404cbcb88f7b300ced437ad794323c Mon Sep 17 00:00:00 2001 From: Scala Steward Date: Mon, 18 Sep 2023 17:26:24 +0000 Subject: [PATCH] Reformat with scalafmt 3.7.14 Executed command: scalafmt --non-interactive --- .../main/scala/org/bykn/bosatsu/Macro.scala | 20 +- .../scala/org/bykn/bosatsu/TestBench.scala | 31 +- .../scala/org/bykn/bosatsu/PathModule.scala | 82 +- .../org/bykn/bosatsu/TypedExprToProto.scala | 1177 ++++++++---- .../scala/org/bykn/bosatsu/JsonTest.scala | 13 +- .../org/bykn/bosatsu/PathModuleTest.scala | 160 +- .../org/bykn/bosatsu/TestProtoType.scala | 134 +- .../bosatsu/codegen/python/CodeTest.scala | 199 +- .../codegen/python/PythonGenTest.scala | 73 +- .../src/main/scala/org/bykn/bosatsu/Par.scala | 12 +- .../src/main/scala/org/bykn/bosatsu/Par.scala | 15 +- .../org/bykn/bosatsu/BindingStatement.scala | 10 +- .../org/bykn/bosatsu/CollectionUtils.scala | 31 +- .../org/bykn/bosatsu/CommentStatement.scala | 24 +- .../scala/org/bykn/bosatsu/Declaration.scala | 1088 ++++++----- .../org/bykn/bosatsu/DefRecursionCheck.scala | 322 ++-- .../scala/org/bykn/bosatsu/DefStatement.scala | 18 +- .../scala/org/bykn/bosatsu/EditDistance.scala | 6 +- .../scala/org/bykn/bosatsu/Evaluation.scala | 115 +- .../scala/org/bykn/bosatsu/ExportedName.scala | 155 +- .../main/scala/org/bykn/bosatsu/Expr.scala | 198 +- .../main/scala/org/bykn/bosatsu/FfiCall.scala | 13 +- .../src/main/scala/org/bykn/bosatsu/Fix.scala | 9 +- .../scala/org/bykn/bosatsu/Identifier.scala | 49 +- .../main/scala/org/bykn/bosatsu/Import.scala | 55 +- .../scala/org/bykn/bosatsu/Indented.scala | 20 +- .../scala/org/bykn/bosatsu/IorMethods.scala | 4 +- .../main/scala/org/bykn/bosatsu/Json.scala | 63 +- .../main/scala/org/bykn/bosatsu/Kind.scala | 6 +- .../scala/org/bykn/bosatsu/ListLang.scala | 95 +- .../src/main/scala/org/bykn/bosatsu/Lit.scala | 13 +- .../scala/org/bykn/bosatsu/LocationMap.scala | 73 +- .../scala/org/bykn/bosatsu/Matchless.scala | 437 +++-- .../bykn/bosatsu/MatchlessFromTypedExpr.scala | 47 +- .../org/bykn/bosatsu/MatchlessToValue.scala | 125 +- .../scala/org/bykn/bosatsu/MemoryMain.scala | 64 +- .../scala/org/bykn/bosatsu/NameKind.scala | 29 +- .../scala/org/bykn/bosatsu/Operators.scala | 88 +- .../scala/org/bykn/bosatsu/OptIndent.scala | 24 +- .../main/scala/org/bykn/bosatsu/Package.scala | 357 ++-- .../scala/org/bykn/bosatsu/PackageError.scala | 681 ++++--- .../scala/org/bykn/bosatsu/PackageMap.scala | 483 +++-- .../scala/org/bykn/bosatsu/PackageName.scala | 5 +- .../main/scala/org/bykn/bosatsu/Padding.scala | 18 +- .../main/scala/org/bykn/bosatsu/Parser.scala | 196 +- .../main/scala/org/bykn/bosatsu/PathGen.scala | 25 +- .../main/scala/org/bykn/bosatsu/Pattern.scala | 588 +++--- .../main/scala/org/bykn/bosatsu/Predef.scala | 99 +- .../main/scala/org/bykn/bosatsu/Program.scala | 9 +- .../scala/org/bykn/bosatsu/Referant.scala | 112 +- .../org/bykn/bosatsu/SourceConverter.scala | 1591 +++++++++------- .../scala/org/bykn/bosatsu/Statement.scala | 414 +++-- .../scala/org/bykn/bosatsu/StringUtil.scala | 66 +- .../main/scala/org/bykn/bosatsu/Test.scala | 34 +- .../org/bykn/bosatsu/TotalityCheck.scala | 462 +++-- .../scala/org/bykn/bosatsu/TypeParser.scala | 78 +- .../main/scala/org/bykn/bosatsu/TypeRef.scala | 89 +- .../org/bykn/bosatsu/TypeRefConverter.scala | 51 +- .../scala/org/bykn/bosatsu/TypedExpr.scala | 749 ++++---- .../bykn/bosatsu/TypedExprNormalization.scala | 402 +++-- .../org/bykn/bosatsu/UnusedLetCheck.scala | 81 +- .../main/scala/org/bykn/bosatsu/Value.scala | 157 +- .../scala/org/bykn/bosatsu/ValueToDoc.scala | 162 +- .../scala/org/bykn/bosatsu/ValueToJson.scala | 570 +++--- .../scala/org/bykn/bosatsu/Variance.scala | 49 +- .../bykn/bosatsu/codegen/python/Code.scala | 262 ++- .../bosatsu/codegen/python/PythonGen.scala | 1161 +++++++----- .../scala/org/bykn/bosatsu/graph/Dag.scala | 3 +- .../org/bykn/bosatsu/graph/Memoize.scala | 45 +- .../scala/org/bykn/bosatsu/graph/Paths.scala | 37 +- .../org/bykn/bosatsu/graph/Toposort.scala | 42 +- .../scala/org/bykn/bosatsu/graph/Tree.scala | 39 +- .../org/bykn/bosatsu/pattern/Matcher.scala | 17 +- .../bosatsu/pattern/NamedSeqPattern.scala | 109 +- .../org/bykn/bosatsu/pattern/SeqPart.scala | 34 +- .../org/bykn/bosatsu/pattern/SeqPattern.scala | 278 +-- .../org/bykn/bosatsu/pattern/SetOps.scala | 112 +- .../org/bykn/bosatsu/pattern/Splitter.scala | 41 +- .../org/bykn/bosatsu/rankn/DataRepr.scala | 9 +- .../org/bykn/bosatsu/rankn/DefinedType.scala | 72 +- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 1253 ++++++++----- .../bykn/bosatsu/rankn/ParsedTypeEnv.scala | 11 +- .../scala/org/bykn/bosatsu/rankn/Ref.scala | 52 +- .../scala/org/bykn/bosatsu/rankn/Type.scala | 320 ++-- .../org/bykn/bosatsu/rankn/TypeEnv.scala | 136 +- .../bykn/bosatsu/CollectionUtilsTest.scala | 9 +- .../org/bykn/bosatsu/DeclarationTest.scala | 142 +- .../bykn/bosatsu/DefRecursionCheckTest.scala | 8 +- .../org/bykn/bosatsu/EvaluationTest.scala | 1601 ++++++++++++----- .../scala/org/bykn/bosatsu/FreeVarTest.scala | 19 +- .../src/test/scala/org/bykn/bosatsu/Gen.scala | 927 +++++++--- .../test/scala/org/bykn/bosatsu/GenJson.scala | 34 +- .../scala/org/bykn/bosatsu/GenValue.scala | 11 +- .../test/scala/org/bykn/bosatsu/IntLaws.scala | 71 +- .../scala/org/bykn/bosatsu/JsonTest.scala | 82 +- .../org/bykn/bosatsu/KindFormulaTest.scala | 14 +- .../org/bykn/bosatsu/LocationMapTest.scala | 21 +- .../org/bykn/bosatsu/MatchlessTests.scala | 75 +- .../scala/org/bykn/bosatsu/MonadGen.scala | 2 +- .../scala/org/bykn/bosatsu/OperatorTest.scala | 51 +- .../scala/org/bykn/bosatsu/PackageTest.scala | 39 +- .../test/scala/org/bykn/bosatsu/ParTest.scala | 2 +- .../scala/org/bykn/bosatsu/ParserTest.scala | 1227 +++++++++---- .../scala/org/bykn/bosatsu/PatternTest.scala | 37 +- .../org/bykn/bosatsu/SelfCallKindTest.scala | 18 +- .../bykn/bosatsu/SourceConverterTest.scala | 64 +- .../scala/org/bykn/bosatsu/TestUtils.scala | 140 +- .../scala/org/bykn/bosatsu/TotalityTest.scala | 356 ++-- .../scala/org/bykn/bosatsu/TypeRefTest.scala | 19 +- .../org/bykn/bosatsu/TypedExprTest.scala | 325 ++-- .../scala/org/bykn/bosatsu/ValueTest.scala | 11 +- .../org/bykn/bosatsu/ValueToDocTest.scala | 62 +- .../scala/org/bykn/bosatsu/VarianceTest.scala | 25 +- .../codegen/python/PythonGenTest.scala | 20 +- .../org/bykn/bosatsu/graph/ToposortTest.scala | 42 +- .../org/bykn/bosatsu/graph/TreeTest.scala | 34 +- .../bykn/bosatsu/pattern/SeqPatternTest.scala | 349 +++- .../org/bykn/bosatsu/pattern/SetOpsLaws.scala | 157 +- .../pattern/StringSeqPatternSetLaws.scala | 68 +- .../org/bykn/bosatsu/rankn/NTypeGen.scala | 61 +- .../bykn/bosatsu/rankn/RankNInferTest.scala | 749 +++++--- .../org/bykn/bosatsu/rankn/TypeTest.scala | 81 +- .../scala/org/bykn/bosatsu/jsapi/JsApi.scala | 48 +- project/plugins.sbt | 1 - 124 files changed, 14926 insertions(+), 8734 deletions(-) diff --git a/base/src/main/scala/org/bykn/bosatsu/Macro.scala b/base/src/main/scala/org/bykn/bosatsu/Macro.scala index 6e09b1740..623b5f2f6 100644 --- a/base/src/main/scala/org/bykn/bosatsu/Macro.scala +++ b/base/src/main/scala/org/bykn/bosatsu/Macro.scala @@ -15,14 +15,15 @@ class Macro(val c: Context) { if (f.exists()) { val res = Source.fromFile(s, "UTF-8").getLines().mkString("\n") Some(c.Expr[String](q"$res")) - } - else { + } else { None } - } - catch { + } catch { case NonFatal(err) => - c.abort(c.enclosingPosition, s"could not read existing file: $s. Exception: $err") + c.abort( + c.enclosingPosition, + s"could not read existing file: $s. Exception: $err" + ) } file.tree match { @@ -34,11 +35,14 @@ class Macro(val c: Context) { .getOrElse { c.abort( c.enclosingPosition, - s"no file found at: $s. working directory is ${System.getProperty("user.dir")}") + s"no file found at: $s. working directory is ${System.getProperty("user.dir")}" + ) } case otherTree => - c.abort(c.enclosingPosition, s"expected string literal, found: $otherTree") + c.abort( + c.enclosingPosition, + s"expected string literal, found: $otherTree" + ) } } } - diff --git a/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala b/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala index 9ceedc3d1..b1aa3233e 100644 --- a/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala +++ b/bench/src/main/scala/org/bykn/bosatsu/TestBench.scala @@ -12,7 +12,10 @@ class TestBench { // don't use threads in the benchmark which will complicate matters import DirectEC.directEC - private def prepPackages(packages: List[String], mainPackS: String): (PackageMap.Inferred, PackageName) = { + private def prepPackages( + packages: List[String], + mainPackS: String + ): (PackageMap.Inferred, PackageName) = { val mainPack = PackageName.parse(mainPackS).get val parsed = packages.zipWithIndex.traverse { case (pack, i) => @@ -28,11 +31,18 @@ class TestBench { val d = p.showContext(LocationMap.Colorize.None) System.err.println(d.render(100)) } - sys.error("failed to parse") //errs.toString) + sys.error("failed to parse") // errs.toString) } - implicit val show: Show[(String, LocationMap)] = Show.show { case (s, _) => s } - PackageMap.resolveThenInfer(PackageMap.withPredefA(("predef", LocationMap("")), parsedPaths), Nil).strictToValidated match { + implicit val show: Show[(String, LocationMap)] = Show.show { case (s, _) => + s + } + PackageMap + .resolveThenInfer( + PackageMap.withPredefA(("predef", LocationMap("")), parsedPaths), + Nil + ) + .strictToValidated match { case Validated.Valid(packMap) => (packMap, mainPack) case other => sys.error(s"expected clean compilation: $other") @@ -40,11 +50,13 @@ class TestBench { } def gauss(n: Int) = prepPackages( - List(s""" + List(s""" package Gauss gauss$n = range($n).foldLeft(0, add) -"""), "Gauss") +"""), + "Gauss" + ) val compiled0: (PackageMap.Inferred, PackageName) = gauss(10) @@ -69,7 +81,8 @@ gauss$n = range($n).foldLeft(0, add) } val compiled2 = - prepPackages(List(""" + prepPackages( + List(""" package Euler4 def operator >(a, b): @@ -132,7 +145,9 @@ max_pal_opt = max_of(99, \n1 -> first_of(99, product_palindrome(n1))) max_pal = match max_pal_opt: Some(m): m None: 0 -"""), "Euler4") +"""), + "Euler4" + ) @Benchmark def bench2(): Unit = { val c = compiled2 diff --git a/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala b/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala index 1c6de286c..2ef23ac03 100644 --- a/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala +++ b/cli/src/main/scala/org/bykn/bosatsu/PathModule.scala @@ -39,7 +39,10 @@ object PathModule extends MainModule[IO] { def readInterfaces(paths: List[Path]): IO[List[Package.Interface]] = ProtoConverter.readInterfaces(paths) - def writeInterfaces(interfaces: List[Package.Interface], path: Path): IO[Unit] = + def writeInterfaces( + interfaces: List[Package.Interface], + path: Path + ): IO[Unit] = ProtoConverter.writeInterfaces(interfaces, path) def writePackages[A](packages: List[Package.Typed[A]], path: Path): IO[Unit] = @@ -54,14 +57,14 @@ object PathModule extends MainModule[IO] { Some(IO { f.listFiles.iterator.map(_.toPath).toList }) - } - else None + } else None } } } - def hasExtension(str: String): Path => Boolean = - { (path: Path) => path.toString.endsWith(str) } + def hasExtension(str: String): Path => Boolean = { (path: Path) => + path.toString.endsWith(str) + } def print(str: => String): IO[Unit] = IO(println(str)) @@ -71,27 +74,37 @@ object PathModule extends MainModule[IO] { def report(io: IO[Output]): IO[ExitCode] = io.attempt.flatMap { case Right(out) => reportOutput(out) - case Left(err) => reportException(err).as(ExitCode.Error) + case Left(err) => reportException(err).as(ExitCode.Error) } def reportOutput(out: Output): IO[ExitCode] = out match { case Output.TestOutput(resMap, color) => val noTests = resMap.collect { case (p, None) => p }.toList - val results = resMap.collect { case (p, Some(t)) => (p, Test.report(t.value, color)) }.toList.sortBy(_._1) + val results = resMap + .collect { case (p, Some(t)) => (p, Test.report(t.value, color)) } + .toList + .sortBy(_._1) val successes = results.iterator.map { case (_, (s, _, _)) => s }.sum val failures = results.iterator.map { case (_, (_, f, _)) => f }.sum val success = noTests.isEmpty && (failures == 0) val suffix = - if (results.lengthCompare(1) > 0) (Doc.hardLine + Doc.hardLine + Test.summary(successes, failures, color)) + if (results.lengthCompare(1) > 0) + (Doc.hardLine + Doc.hardLine + Test.summary( + successes, + failures, + color + )) else Doc.empty val docRes: Doc = - Doc.intercalate(Doc.hardLine + Doc.hardLine, + Doc.intercalate( + Doc.hardLine + Doc.hardLine, results.map { case (p, (_, _, d)) => - Doc.text(p.asString) + Doc.char(':') + (Doc.lineOrSpace + d).nested(2) - }) + suffix - + Doc.text(p.asString) + Doc.char(':') + (Doc.lineOrSpace + d) + .nested(2) + } + ) + suffix if (success) print(docRes.render(80)).as(ExitCode.Success) else { @@ -99,11 +112,17 @@ object PathModule extends MainModule[IO] { if (noTests.isEmpty) Nil else { val prefix = Doc.text("packages with missing tests: ") - val missingDoc = Doc.intercalate(Doc.comma + Doc.lineOrSpace, noTests.sorted.map { p => Doc.text(p.asString) }) + val missingDoc = Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + noTests.sorted.map { p => Doc.text(p.asString) } + ) (prefix + missingDoc.nested(2)) :: Nil } - val fullOut = Doc.intercalate(Doc.hardLine + Doc.hardLine + (Doc.char('#') * 80) + Doc.line, docRes :: missingDoc) + val fullOut = Doc.intercalate( + Doc.hardLine + Doc.hardLine + (Doc.char('#') * 80) + Doc.line, + docRes :: missingDoc + ) val failureStr = if (failures == 1) "1 test failure" @@ -114,33 +133,37 @@ object PathModule extends MainModule[IO] { if (missingCount > 0) { val packString = if (missingCount == 1) "package" else "packages" s"$failureStr and $missingCount $packString with no tests found" - } - else failureStr + } else failureStr - print((fullOut + Doc.hardLine + Doc.hardLine + Doc.text(excepMessage)).render(80)) + print( + (fullOut + Doc.hardLine + Doc.hardLine + Doc.text(excepMessage)) + .render(80) + ) .as(ExitCode.Error) } case Output.EvaluationResult(_, tpe, resDoc) => val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) - val doc = resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) + val doc = + resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) print(doc.render(100)).as(ExitCode.Success) case Output.JsonOutput(json, pathOpt) => val jdoc = json.toDoc (pathOpt match { case Some(path) => CodeGenWrite.writeDoc(path, jdoc) - case None => IO(println(jdoc.renderTrim(80))) + case None => IO(println(jdoc.renderTrim(80))) }).as(ExitCode.Success) case Output.TranspileOut(outs, base) => def path(p: List[String]): Path = p.foldLeft(base)(_.resolve(_)) - outs.toList.map { case (p, d) => - (p, CodeGenWrite.writeDoc(path(p.toList), d)) - } - .sortBy(_._1) - .traverse_ { case (_, w) => w } - .as(ExitCode.Success) + outs.toList + .map { case (p, d) => + (p, CodeGenWrite.writeDoc(path(p.toList), d)) + } + .sortBy(_._1) + .traverse_ { case (_, w) => w } + .as(ExitCode.Success) case Output.CompileOut(packList, ifout, output) => val ifres = ifout match { @@ -168,7 +191,8 @@ object PathModule extends MainModule[IO] { import scala.jdk.CollectionConverters._ def getP(p: Path): Option[PackageName] = { - val subPath = p.relativize(packFile) + val subPath = p + .relativize(packFile) .asScala .map { part => part.toString.toLowerCase.capitalize @@ -178,7 +202,7 @@ object PathModule extends MainModule[IO] { val dropExtension = """(.*)\.[^.]*$""".r val toParse = subPath match { case dropExtension(prefix) => prefix - case _ => subPath + case _ => subPath } PackageName.parse(toParse) } @@ -186,9 +210,9 @@ object PathModule extends MainModule[IO] { @annotation.tailrec def loop(roots: List[Path]): Option[PackageName] = roots match { - case Nil => None + case Nil => None case h :: _ if packFile.startsWith(h) => getP(h) - case _ :: t => loop(t) + case _ :: t => loop(t) } if (packFile.toString.isEmpty) None diff --git a/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala b/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala index 0bc64af88..adfdc48c5 100644 --- a/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala +++ b/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala @@ -6,7 +6,12 @@ import cats.data.{NonEmptyList, ReaderT, StateT} import cats.effect.IO import org.bykn.bosatsu.graph.Memoize import java.nio.file.Path -import java.io.{FileInputStream, FileOutputStream, BufferedInputStream, BufferedOutputStream} +import java.io.{ + FileInputStream, + FileOutputStream, + BufferedInputStream, + BufferedOutputStream +} import org.bykn.bosatsu.rankn.{DefinedType, Type, TypeEnv} import scala.util.{Failure, Success, Try} import scala.reflect.ClassTag @@ -17,9 +22,8 @@ import Identifier.{Bindable, Constructor} import cats.implicits._ -/** - * convert TypedExpr to and from Protobuf representation - */ +/** convert TypedExpr to and from Protobuf representation + */ object ProtoConverter { case class IdAssignment[A1, A2](mapping: Map[A1, Int], inOrder: Vector[A2]) { def get(a1: A1, a2: => A2): Either[(IdAssignment[A1, A2], Int), Int] = @@ -27,9 +31,8 @@ object ProtoConverter { case Some(id) => Right(id) case None => val id = inOrder.size - val next = copy( - mapping = mapping.updated(a1, id), - inOrder = inOrder :+ a2) + val next = + copy(mapping = mapping.updated(a1, id), inOrder = inOrder :+ a2) Left((next, id)) } @@ -38,25 +41,42 @@ object ProtoConverter { } object IdAssignment { - def empty[A1, A2]: IdAssignment[A1, A2] = IdAssignment(Map.empty, Vector.empty) + def empty[A1, A2]: IdAssignment[A1, A2] = + IdAssignment(Map.empty, Vector.empty) } case class SerState( - strings: IdAssignment[String, String], - types: IdAssignment[Type, proto.Type], - patterns: IdAssignment[Pattern[(PackageName, Constructor), Type], proto.Pattern], - expressions: IdAssignment[TypedExpr[Any], proto.TypedExpr]) { + strings: IdAssignment[String, String], + types: IdAssignment[Type, proto.Type], + patterns: IdAssignment[ + Pattern[(PackageName, Constructor), Type], + proto.Pattern + ], + expressions: IdAssignment[TypedExpr[Any], proto.TypedExpr] + ) { def stringId(s: String): Either[(SerState, Int), Int] = - strings.get(s, s).left.map { case (next, id) => (copy(strings = next), id) } + strings.get(s, s).left.map { case (next, id) => + (copy(strings = next), id) + } - def typeId(t: Type, protoType: => proto.Type): Either[(SerState, Int), Int] = - types.get(t, protoType).left.map { case (next, id) => (copy(types = next), id) } + def typeId( + t: Type, + protoType: => proto.Type + ): Either[(SerState, Int), Int] = + types.get(t, protoType).left.map { case (next, id) => + (copy(types = next), id) + } } object SerState { val empty: SerState = - SerState(IdAssignment.empty, IdAssignment.empty, IdAssignment.empty, IdAssignment.empty) + SerState( + IdAssignment.empty, + IdAssignment.empty, + IdAssignment.empty, + IdAssignment.empty + ) } type Tab[A] = StateT[Try, SerState, A] @@ -73,7 +93,7 @@ object ProtoConverter { case Failure(_) => System.err.println(message) self - } + } } private def tabFail[S, A](ex: Exception): Tab[A] = @@ -82,13 +102,14 @@ object ProtoConverter { Monad[Tab].pure(a) private def get(fn: SerState => Either[(SerState, Int), Int]): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .flatMap { ss => fn(ss) match { - case Right(idx) => StateT.pure(idx + 1) - case Left((ss, idx)) => - StateT.set[Try, SerState](ss).as(idx + 1) - } + case Right(idx) => StateT.pure(idx + 1) + case Left((ss, idx)) => + StateT.set[Try, SerState](ss).as(idx + 1) + } } private def getId(s: String): Tab[Int] = get(_.stringId(s)) @@ -97,11 +118,16 @@ object ProtoConverter { get(_.typeId(t, pt)) private def getProtoTypeTab(t: Type): Tab[Option[Int]] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .map(_.types.indexOf(t).map(_ + 1)) - private def writePattern(p: Pattern[(PackageName, Constructor), Type], pp: proto.Pattern): Tab[Int] = - StateT.get[Try, SerState] + private def writePattern( + p: Pattern[(PackageName, Constructor), Type], + pp: proto.Pattern + ): Tab[Int] = + StateT + .get[Try, SerState] .flatMap { s => s.patterns.get(p, pp) match { case Right(_) => @@ -113,7 +139,8 @@ object ProtoConverter { } private def writeExpr(te: TypedExpr[Any], pte: proto.TypedExpr): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .flatMap { s => s.expressions.get(te, pte) match { case Right(_) => @@ -128,11 +155,12 @@ object ProtoConverter { t.run(SerState.empty) class DecodeState private ( - strings: Array[String], - types: Array[Type], - dts: Array[DefinedType[Kind.Arg]], - patterns: Array[Pattern[(PackageName, Constructor), Type]], - expr: Array[TypedExpr[Unit]]) { + strings: Array[String], + types: Array[Type], + dts: Array[DefinedType[Kind.Arg]], + patterns: Array[Pattern[(PackageName, Constructor), Type]], + expr: Array[TypedExpr[Unit]] + ) { def getString(idx: Int): Option[String] = if ((0 <= idx) && (idx < strings.length)) Some(strings(idx)) else None @@ -149,7 +177,10 @@ object ProtoConverter { if ((0 <= idx) && (idx < types.length)) Success(types(idx)) else Failure(new Exception(msg)) - def tryPattern(idx: Int, msg: => String): Try[Pattern[(PackageName, Constructor), Type]] = + def tryPattern( + idx: Int, + msg: => String + ): Try[Pattern[(PackageName, Constructor), Type]] = if ((0 <= idx) && (idx < patterns.length)) Success(patterns(idx)) else Failure(new Exception(msg)) @@ -170,7 +201,9 @@ object ProtoConverter { def withTypes(ary: Array[Type]): DecodeState = new DecodeState(strings, ary, dts, patterns, expr) - def withPatterns(ary: Array[Pattern[(PackageName, Constructor), Type]]): DecodeState = + def withPatterns( + ary: Array[Pattern[(PackageName, Constructor), Type]] + ): DecodeState = new DecodeState(strings, types, dts, ary, expr) def withExprs(ary: Array[TypedExpr[Unit]]): DecodeState = @@ -179,12 +212,20 @@ object ProtoConverter { object DecodeState { def init(strings: Seq[String]): DecodeState = - new DecodeState(strings.toArray, Array.empty, Array.empty, Array.empty, Array.empty) + new DecodeState( + strings.toArray, + Array.empty, + Array.empty, + Array.empty, + Array.empty + ) } type DTab[A] = ReaderT[Try, DecodeState, A] - private def find[A](idx: Int, context: => String)(fn: (DecodeState, Int) => Option[A]): DTab[A] = + private def find[A](idx: Int, context: => String)( + fn: (DecodeState, Int) => Option[A] + ): DTab[A] = ReaderT { decodeState => fn(decodeState, idx - 1) match { case Some(s) => Success(s) @@ -198,29 +239,37 @@ object ProtoConverter { private def lookupType(idx: Int, context: => String): DTab[Type] = find(idx, context)(_.getType(_)) - private def lookupDts(idx: Int, context: => String): DTab[DefinedType[Kind.Arg]] = + private def lookupDts( + idx: Int, + context: => String + ): DTab[DefinedType[Kind.Arg]] = find(idx, context)(_.getDt(_)) private def lookupExpr(idx: Int, context: => String): DTab[TypedExpr[Unit]] = find(idx, context)(_.getExpr(_)) - /** - * this is code to build tables of serialized dags. We use this for types, patterns, expressions - */ - private def buildTable[A, B: ClassTag](ary: Array[A])(fn: (A, Int => Try[B]) => Try[B]): Try[Array[B]] = { + /** this is code to build tables of serialized dags. We use this for types, + * patterns, expressions + */ + private def buildTable[A, B: ClassTag]( + ary: Array[A] + )(fn: (A, Int => Try[B]) => Try[B]): Try[Array[B]] = { val result = new Array[B](ary.length) def lookup(a: A, max: Int): Int => Try[B] = { idx => if (idx > 0 && idx <= max) Success(result(idx - 1)) - else Failure(new Exception(s"while decoding $a, invalid index $idx, max: $max")) + else + Failure( + new Exception(s"while decoding $a, invalid index $idx, max: $max") + ) } var idx = 0 var res: Failure[Array[B]] = null - while((idx < ary.length) && (res eq null)) { + while ((idx < ary.length) && (res eq null)) { val a = ary(idx) val lookupFn = lookup(a, idx) fn(a, lookupFn) match { - case Success(b) => result(idx) = b + case Success(b) => result(idx) = b case Failure(err) => res = Failure(err) } idx = idx + 1 @@ -231,7 +280,6 @@ object ProtoConverter { def buildTypes(types: Seq[proto.Type]): DTab[Array[Type]] = ReaderT[Try, DecodeState, Array[Type]] { ds => - def typeFromProto(p: proto.Type, tpe: Int => Try[Type]): Try[Type] = { import proto.Type.Value import bosatsu.TypedAst.{Type => _, _} @@ -269,10 +317,18 @@ object ProtoConverter { buildTable(types.toArray)(typeFromProto _) } - def buildPatterns(pats: Seq[proto.Pattern]): DTab[Array[Pattern[(PackageName, Constructor), Type]]] = - ReaderT[Try, DecodeState, Array[Pattern[(PackageName, Constructor), Type]]] { ds => - - def patternFromProto(p: proto.Pattern, pat: Int => Try[Pattern[(PackageName, Constructor), Type]]): Try[Pattern[(PackageName, Constructor), Type]] = { + def buildPatterns( + pats: Seq[proto.Pattern] + ): DTab[Array[Pattern[(PackageName, Constructor), Type]]] = + ReaderT[ + Try, + DecodeState, + Array[Pattern[(PackageName, Constructor), Type]] + ] { ds => + def patternFromProto( + p: proto.Pattern, + pat: Int => Try[Pattern[(PackageName, Constructor), Type]] + ): Try[Pattern[(PackageName, Constructor), Type]] = { import proto.Pattern.Value def str(i: Int): Try[String] = @@ -290,22 +346,32 @@ object ProtoConverter { case Value.NamedPat(proto.NamedPat(nidx, pidx, _)) => (bindable(nidx), pat(pidx)).mapN(Pattern.Named(_, _)) case Value.ListPat(proto.ListPat(lp, _)) => - def decodePart(part: proto.ListPart): Try[Pattern.ListPart[Pattern[(PackageName, Constructor), Type]]] = + def decodePart(part: proto.ListPart): Try[ + Pattern.ListPart[Pattern[(PackageName, Constructor), Type]] + ] = part.value match { - case proto.ListPart.Value.Empty => Failure(new Exception(s"invalid empty list pattern in $p")) - case proto.ListPart.Value.ItemPattern(p) => pat(p).map(Pattern.ListPart.Item(_)) - case proto.ListPart.Value.UnnamedList(_) => Success(Pattern.ListPart.WildList) - case proto.ListPart.Value.NamedList(idx) => bindable(idx).map { n => Pattern.ListPart.NamedList(n) } + case proto.ListPart.Value.Empty => + Failure(new Exception(s"invalid empty list pattern in $p")) + case proto.ListPart.Value.ItemPattern(p) => + pat(p).map(Pattern.ListPart.Item(_)) + case proto.ListPart.Value.UnnamedList(_) => + Success(Pattern.ListPart.WildList) + case proto.ListPart.Value.NamedList(idx) => + bindable(idx).map { n => Pattern.ListPart.NamedList(n) } } lp.toList.traverse(decodePart).map(Pattern.ListPat(_)) case Value.StrPat(proto.StrPat(items, _)) => def decodePart(part: proto.StrPart): Try[Pattern.StrPart] = part.value match { - case proto.StrPart.Value.Empty => Failure(new Exception(s"invalid empty list pattern in $p")) - case proto.StrPart.Value.LiteralStr(idx) => str(idx).map(Pattern.StrPart.LitStr(_)) - case proto.StrPart.Value.UnnamedStr(_) => Success(Pattern.StrPart.WildStr) - case proto.StrPart.Value.NamedStr(idx) => bindable(idx).map { n => Pattern.StrPart.NamedStr(n) } + case proto.StrPart.Value.Empty => + Failure(new Exception(s"invalid empty list pattern in $p")) + case proto.StrPart.Value.LiteralStr(idx) => + str(idx).map(Pattern.StrPart.LitStr(_)) + case proto.StrPart.Value.UnnamedStr(_) => + Success(Pattern.StrPart.WildStr) + case proto.StrPart.Value.NamedStr(idx) => + bindable(idx).map { n => Pattern.StrPart.NamedStr(n) } } items.toList match { @@ -318,12 +384,17 @@ object ProtoConverter { .map(Pattern.StrPat(_)) } case Value.AnnotationPat(proto.AnnotationPat(pidx, tidx, _)) => - (pat(pidx), ds.tryType(tidx - 1, s"invalid type index $tidx in: $p")) + ( + pat(pidx), + ds.tryType(tidx - 1, s"invalid type index $tidx in: $p") + ) .mapN(Pattern.Annotation(_, _)) case Value.StructPat(proto.StructPattern(packIdx, cidx, args, _)) => str(packIdx) .product(str(cidx)) - .flatMap { case (p, c) => fullNameFromStr(p, c, s"invalid structpat names: $p, $c") } + .flatMap { case (p, c) => + fullNameFromStr(p, c, s"invalid structpat names: $p, $c") + } .flatMap { pc => args.toList.traverse(pat).map(Pattern.PositionalStruct(pc, _)) } @@ -336,7 +407,11 @@ object ProtoConverter { } case notTwo => - Failure(new Exception(s"invalid union found size: ${notTwo.size}, expected 2 or more")) + Failure( + new Exception( + s"invalid union found size: ${notTwo.size}, expected 2 or more" + ) + ) } } } @@ -344,16 +419,23 @@ object ProtoConverter { buildTable(pats.toArray)(patternFromProto _) } - def recursionKindFromProto(rec: proto.RecursionKind, context: => String): Try[RecursionKind] = + def recursionKindFromProto( + rec: proto.RecursionKind, + context: => String + ): Try[RecursionKind] = rec match { case proto.RecursionKind.NotRec => Success(RecursionKind.NonRecursive) - case proto.RecursionKind.IsRec => Success(RecursionKind.Recursive) - case other => Failure(new Exception(s"invalid recursion kind: $other, in $context")) + case proto.RecursionKind.IsRec => Success(RecursionKind.Recursive) + case other => + Failure(new Exception(s"invalid recursion kind: $other, in $context")) } def buildExprs(exprs: Seq[proto.TypedExpr]): DTab[Array[TypedExpr[Unit]]] = ReaderT[Try, DecodeState, Array[TypedExpr[Unit]]] { ds => - def expressionFromProto(ex: proto.TypedExpr, exprOf: Int => Try[TypedExpr[Unit]]): Try[TypedExpr[Unit]] = { + def expressionFromProto( + ex: proto.TypedExpr, + exprOf: Int => Try[TypedExpr[Unit]] + ): Try[TypedExpr[Unit]] = { import proto.TypedExpr.Value def str(i: Int): Try[String] = @@ -370,50 +452,69 @@ object ProtoConverter { ex.value match { case Value.Empty => Failure(new Exception("invalid empty TypedExpr")) - case ge @ Value.GenericExpr(proto.GenericExpr(typeParams, kinds, expr, _)) => + case ge @ Value.GenericExpr( + proto.GenericExpr(typeParams, kinds, expr, _) + ) => if (typeParams.length != kinds.length) Failure(new Exception(s"bound and kinds length mismatch in $ge")) else NonEmptyList.fromList(typeParams.toList) match { case Some(nel) => - (nel.traverse(str), kinds.traverse { k => kindFromProto(Some(k)) }, exprOf(expr)) + ( + nel.traverse(str), + kinds.traverse { k => kindFromProto(Some(k)) }, + exprOf(expr) + ) .mapN { (strs, kindsSeq, expr) => // we know the length is the same as the params NEL val kinds = NonEmptyList.fromListUnsafe(kindsSeq.toList) val bs = strs.map(Type.Var.Bound(_)) TypedExpr.Generic(bs.zip(kinds), expr) } - case None => Failure(new Exception(s"invalid empty type params in generic($ge): $ex")) + case None => + Failure( + new Exception( + s"invalid empty type params in generic($ge): $ex" + ) + ) } case Value.AnnotationExpr(proto.AnnotationExpr(expr, tpe, _)) => (exprOf(expr), typeOf(tpe)) .mapN(TypedExpr.Annotation(_, _)) case Value.LambdaExpr(proto.LambdaExpr(varsName, varsTpe, expr, _)) => - (varsName.traverse(bindable(_)), varsTpe.traverse(typeOf(_)), exprOf(expr)) + ( + varsName.traverse(bindable(_)), + varsTpe.traverse(typeOf(_)), + exprOf(expr) + ) .flatMapN { (vs, ts, e) => val vsLen = vs.length if (vsLen <= 0) { Failure(new Exception(s"no bind names in this lambda: $ex")) - } - else if (vsLen == ts.length) { + } else if (vsLen == ts.length) { // we know length > 0 and they match - val args = NonEmptyList.fromListUnsafe(vs.iterator.zip(ts.iterator).toList) + val args = NonEmptyList.fromListUnsafe( + vs.iterator.zip(ts.iterator).toList + ) Success(TypedExpr.AnnotatedLambda(args, e, ())) - } - else { - Failure(new Exception(s"type list length didn't match bind name length in $ex")) + } else { + Failure( + new Exception( + s"type list length didn't match bind name length in $ex" + ) + ) } } case Value.VarExpr(proto.VarExpr(pack, varname, tpe, _)) => val tryPack = if (pack == 0) Success(None) - else for { - ps <- str(pack) - pack <- parsePack(ps, s"expression: $ex") - } yield Some(pack) + else + for { + ps <- str(pack) + pack <- parsePack(ps, s"expression: $ex") + } yield Some(pack) - (tryPack, typeOf(tpe)) - .tupled + (tryPack, typeOf(tpe)).tupled .flatMap { case (None, tpe) => bindable(varname).map(TypedExpr.Local(_, tpe, ())) @@ -436,14 +537,20 @@ object ProtoConverter { .mapN(TypedExpr.Let(_, _, _, _, ())) case Value.LiteralExpr(proto.LiteralExpr(lit, tpe, _)) => lit match { - case None => Failure(new Exception(s"invalid missing literal in $ex")) + case None => + Failure(new Exception(s"invalid missing literal in $ex")) case Some(lit) => (litFromProto(lit), typeOf(tpe)) .mapN(TypedExpr.Literal(_, _, ())) } case Value.MatchExpr(proto.MatchExpr(argId, branches, _)) => - def buildBranch(b: proto.Branch): Try[(Pattern[(PackageName, Constructor), Type], TypedExpr[Unit])] = - (ds.tryPattern(b.pattern - 1, s"invalid pattern in $ex"), exprOf(b.resultExpr)).tupled + def buildBranch(b: proto.Branch): Try[ + (Pattern[(PackageName, Constructor), Type], TypedExpr[Unit]) + ] = + ( + ds.tryPattern(b.pattern - 1, s"invalid pattern in $ex"), + exprOf(b.resultExpr) + ).tupled NonEmptyList.fromList(branches.toList) match { case Some(nel) => @@ -465,11 +572,21 @@ object ProtoConverter { case Some(pack) => Success(pack) } - private def fullNameFromStr(pstr: String, tstr: String, context: => String): Try[(PackageName, Constructor)] = + private def fullNameFromStr( + pstr: String, + tstr: String, + context: => String + ): Try[(PackageName, Constructor)] = (parsePack(pstr, context), toConstructor(tstr)).tupled - def typeConstFromStr(pstr: String, tstr: String, context: => String): Try[Type.Const.Defined] = - fullNameFromStr(pstr, tstr, context).map { case (p, c) => Type.Const.Defined(p, TypeName(c)) } + def typeConstFromStr( + pstr: String, + tstr: String, + context: => String + ): Try[Type.Const.Defined] = + fullNameFromStr(pstr, tstr, context).map { case (p, c) => + Type.Const.Defined(p, TypeName(c)) + } def typeConstFromProto(p: proto.TypeConst): DTab[Type.Const.Defined] = { val proto.TypeConst(packidx, tidx, _) = p @@ -508,7 +625,11 @@ object ProtoConverter { .traverse { case (b, _) => getId(b.name) } .flatMap { ids => lazy val ks = bs.map { case (_, k) => kindToProto(k) } - getTypeId(p, proto.Type(Value.TypeForAll(TypeForAll(ids, ks.toList, idx)))) + getTypeId( + p, + proto + .Type(Value.TypeForAll(TypeForAll(ids, ks.toList, idx))) + ) } } case Type.TyApply(on, arg) => @@ -537,8 +658,7 @@ object ProtoConverter { case Lit.Integer(i) => try { proto.Literal.Value.IntValueAs64(i.longValueExact) - } - catch { + } catch { case _: ArithmeticException => proto.Literal.Value.IntValueAsString(i.toString) } @@ -561,78 +681,122 @@ object ProtoConverter { } def patternToProto(p: Pattern[(PackageName, Constructor), Type]): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .map(_.patterns.indexOf(p)) .flatMap { case Some(idx) => tabPure(idx + 1) case None => p match { case Pattern.WildCard => - writePattern(p, proto.Pattern(proto.Pattern.Value.WildPat(proto.WildCardPat()))) + writePattern( + p, + proto.Pattern(proto.Pattern.Value.WildPat(proto.WildCardPat())) + ) case Pattern.Literal(lit) => val litP = litToProto(lit) writePattern(p, proto.Pattern(proto.Pattern.Value.LitPat(litP))) case Pattern.Var(n) => getId(n.sourceCodeRepr) .flatMap { idx => - writePattern(p, proto.Pattern(proto.Pattern.Value.VarNamePat(idx))) + writePattern( + p, + proto.Pattern(proto.Pattern.Value.VarNamePat(idx)) + ) } - case named@Pattern.Named(n, p) => + case named @ Pattern.Named(n, p) => getId(n.sourceCodeRepr) .product(patternToProto(p)) .flatMap { case (idx, pidx) => - writePattern(named, proto.Pattern(proto.Pattern.Value.NamedPat(proto.NamedPat(idx, pidx)))) + writePattern( + named, + proto.Pattern( + proto.Pattern.Value.NamedPat(proto.NamedPat(idx, pidx)) + ) + ) } case Pattern.StrPat(parts) => - parts.traverse { - case Pattern.StrPart.WildStr => - tabPure(proto.StrPart(proto.StrPart.Value.UnnamedStr(proto.WildCardPat()))) - case Pattern.StrPart.NamedStr(n) => - getId(n.sourceCodeRepr).map { idx => - proto.StrPart(proto.StrPart.Value.NamedStr(idx)) - } - case Pattern.StrPart.LitStr(s) => - getId(s).map { idx => - proto.StrPart(proto.StrPart.Value.LiteralStr(idx)) - } - } - .flatMap { parts => - writePattern(p, proto.Pattern(proto.Pattern.Value.StrPat(proto.StrPat(parts.toList)))) - } + parts + .traverse { + case Pattern.StrPart.WildStr => + tabPure( + proto.StrPart( + proto.StrPart.Value.UnnamedStr(proto.WildCardPat()) + ) + ) + case Pattern.StrPart.NamedStr(n) => + getId(n.sourceCodeRepr).map { idx => + proto.StrPart(proto.StrPart.Value.NamedStr(idx)) + } + case Pattern.StrPart.LitStr(s) => + getId(s).map { idx => + proto.StrPart(proto.StrPart.Value.LiteralStr(idx)) + } + } + .flatMap { parts => + writePattern( + p, + proto.Pattern( + proto.Pattern.Value.StrPat(proto.StrPat(parts.toList)) + ) + ) + } case Pattern.ListPat(items) => - items.traverse { - case Pattern.ListPart.Item(itemPat) => - patternToProto(itemPat).map { pidx => - proto.ListPart(proto.ListPart.Value.ItemPattern(pidx)) - } - case Pattern.ListPart.WildList => - tabPure(proto.ListPart(proto.ListPart.Value.UnnamedList(proto.WildCardPat()))) - case Pattern.ListPart.NamedList(bindable) => - getId(bindable.sourceCodeRepr).map { idx => - proto.ListPart(proto.ListPart.Value.NamedList(idx)) - } - } - .flatMap { parts => - writePattern(p, proto.Pattern(proto.Pattern.Value.ListPat(proto.ListPat(parts)))) - } - case ann@Pattern.Annotation(p, tpe) => + items + .traverse { + case Pattern.ListPart.Item(itemPat) => + patternToProto(itemPat).map { pidx => + proto.ListPart(proto.ListPart.Value.ItemPattern(pidx)) + } + case Pattern.ListPart.WildList => + tabPure( + proto.ListPart( + proto.ListPart.Value.UnnamedList(proto.WildCardPat()) + ) + ) + case Pattern.ListPart.NamedList(bindable) => + getId(bindable.sourceCodeRepr).map { idx => + proto.ListPart(proto.ListPart.Value.NamedList(idx)) + } + } + .flatMap { parts => + writePattern( + p, + proto.Pattern( + proto.Pattern.Value.ListPat(proto.ListPat(parts)) + ) + ) + } + case ann @ Pattern.Annotation(p, tpe) => patternToProto(p) .product(typeToProto(tpe)) .flatMap { case (pidx, tidx) => - writePattern(ann, proto.Pattern(proto.Pattern.Value.AnnotationPat(proto.AnnotationPat(pidx, tidx)))) + writePattern( + ann, + proto.Pattern( + proto.Pattern.Value + .AnnotationPat(proto.AnnotationPat(pidx, tidx)) + ) + ) } - case pos@Pattern.PositionalStruct((packName, consName), params) => + case pos @ Pattern.PositionalStruct((packName, consName), params) => typeConstToProto(Type.Const.Defined(packName, TypeName(consName))) .flatMap { ptc => params .traverse(patternToProto) .flatMap { parts => - writePattern(pos, - proto.Pattern(proto.Pattern.Value.StructPat( - proto.StructPattern( - packageName = ptc.packageName, - constructorName = ptc.typeName, - params = parts)))) + writePattern( + pos, + proto.Pattern( + proto.Pattern.Value.StructPat( + proto.StructPattern( + packageName = ptc.packageName, + constructorName = ptc.typeName, + params = parts + ) + ) + ) + ) } } @@ -640,70 +804,98 @@ object ProtoConverter { (h :: t.toList) .traverse(patternToProto) .flatMap { us => - writePattern(p, proto.Pattern(proto.Pattern.Value.UnionPat(proto.UnionPattern(us)))) + writePattern( + p, + proto.Pattern( + proto.Pattern.Value.UnionPat(proto.UnionPattern(us)) + ) + ) } } } def typedExprToProto(te: TypedExpr[Any]): Tab[Int] = - StateT.get[Try, SerState] + StateT + .get[Try, SerState] .map(_.expressions.indexOf(te)) .flatMap { case Some(idx) => tabPure(idx + 1) case None => import TypedExpr._ te match { - case g@Generic(tvars, expr) => - tvars.toList.traverse { case (v, _) => getId(v.name) } + case g @ Generic(tvars, expr) => + tvars.toList + .traverse { case (v, _) => getId(v.name) } .product(typedExprToProto(expr)) .flatMap { case (tparams, exid) => val ks = tvars.map { case (_, k) => kindToProto(k) } val ex = proto.GenericExpr(tparams, ks.toList, exid) - writeExpr(g, proto.TypedExpr(proto.TypedExpr.Value.GenericExpr(ex))) + writeExpr( + g, + proto.TypedExpr(proto.TypedExpr.Value.GenericExpr(ex)) + ) } - case a@Annotation(term, tpe) => + case a @ Annotation(term, tpe) => typedExprToProto(term) .product(typeToProto(tpe)) .flatMap { case (term, tpe) => val ex = proto.AnnotationExpr(term, tpe) - writeExpr(a, proto.TypedExpr(proto.TypedExpr.Value.AnnotationExpr(ex))) + writeExpr( + a, + proto.TypedExpr(proto.TypedExpr.Value.AnnotationExpr(ex)) + ) } - case al@AnnotatedLambda(args, res, _) => - args.toList.traverse { case (n, tpe) => - getId(n.sourceCodeRepr).product(typeToProto(tpe)) - } - .product(typedExprToProto(res)) - .flatMap { case (args, resid) => - val ex = proto.LambdaExpr(args.map(_._1), args.map(_._2), resid) - writeExpr(al, proto.TypedExpr(proto.TypedExpr.Value.LambdaExpr(ex))) - } - case l@Local(nm, tpe, _) => + case al @ AnnotatedLambda(args, res, _) => + args.toList + .traverse { case (n, tpe) => + getId(n.sourceCodeRepr).product(typeToProto(tpe)) + } + .product(typedExprToProto(res)) + .flatMap { case (args, resid) => + val ex = + proto.LambdaExpr(args.map(_._1), args.map(_._2), resid) + writeExpr( + al, + proto.TypedExpr(proto.TypedExpr.Value.LambdaExpr(ex)) + ) + } + case l @ Local(nm, tpe, _) => getId(nm.sourceCodeRepr) .product(typeToProto(tpe)) .flatMap { case (varId, tpeId) => val ex = proto.VarExpr(0, varId, tpeId) - writeExpr(l, proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex))) + writeExpr( + l, + proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex)) + ) } - case g@Global(pack, nm, tpe, _) => - (getId(pack.asString), + case g @ Global(pack, nm, tpe, _) => + ( + getId(pack.asString), getId(nm.sourceCodeRepr), - typeToProto(tpe)) - .tupled + typeToProto(tpe) + ).tupled .flatMap { case (packId, varId, tpeId) => val ex = proto.VarExpr(packId, varId, tpeId) - writeExpr(g, proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex))) + writeExpr( + g, + proto.TypedExpr(proto.TypedExpr.Value.VarExpr(ex)) + ) } - case a@App(fn, args, resTpe, _) => + case a @ App(fn, args, resTpe, _) => typedExprToProto(fn) .product(args.traverse(typedExprToProto(_))) .product(typeToProto(resTpe)) .flatMap { case ((fn, args), resTpe) => val ex = proto.AppExpr(fn, args.toList, resTpe) - writeExpr(a, proto.TypedExpr(proto.TypedExpr.Value.AppExpr(ex))) + writeExpr( + a, + proto.TypedExpr(proto.TypedExpr.Value.AppExpr(ex)) + ) } - case let@Let(nm, nmexpr, inexpr, rec, _) => + case let @ Let(nm, nmexpr, inexpr, rec, _) => val prec = rec match { - case RecursionKind.Recursive => proto.RecursionKind.IsRec + case RecursionKind.Recursive => proto.RecursionKind.IsRec case RecursionKind.NonRecursive => proto.RecursionKind.NotRec } getId(nm.sourceCodeRepr) @@ -711,16 +903,24 @@ object ProtoConverter { .product(typedExprToProto(inexpr)) .flatMap { case ((nm, nmexpr), inexpr) => val ex = proto.LetExpr(nm, nmexpr, inexpr, prec) - writeExpr(let, proto.TypedExpr(proto.TypedExpr.Value.LetExpr(ex))) + writeExpr( + let, + proto.TypedExpr(proto.TypedExpr.Value.LetExpr(ex)) + ) } - case lit@Literal(l, tpe, _) => + case lit @ Literal(l, tpe, _) => typeToProto(tpe) .flatMap { tpe => val ex = proto.LiteralExpr(Some(litToProto(l)), tpe) - writeExpr(lit, proto.TypedExpr(proto.TypedExpr.Value.LiteralExpr(ex))) + writeExpr( + lit, + proto.TypedExpr(proto.TypedExpr.Value.LiteralExpr(ex)) + ) } - case m@Match(argE, branches, _) => - def encodeBranch(p: (Pattern[(PackageName, Constructor), Type], TypedExpr[Any])): Tab[proto.Branch] = + case m @ Match(argE, branches, _) => + def encodeBranch( + p: (Pattern[(PackageName, Constructor), Type], TypedExpr[Any]) + ): Tab[proto.Branch] = (patternToProto(p._1), typedExprToProto(p._2)) .mapN { (pat, expr) => proto.Branch(pat, expr) } @@ -728,7 +928,10 @@ object ProtoConverter { .product(branches.toList.traverse(encodeBranch)) .flatMap { case (argId, branches) => val ex = proto.MatchExpr(argId, branches) - writeExpr(m, proto.TypedExpr(proto.TypedExpr.Value.MatchExpr(ex))) + writeExpr( + m, + proto.TypedExpr(proto.TypedExpr.Value.MatchExpr(ex)) + ) } } } @@ -736,19 +939,20 @@ object ProtoConverter { def varianceToProto(v: Variance): proto.Variance = v match { - case Variance.Phantom => proto.Variance.Phantom - case Variance.Covariant => proto.Variance.Covariant + case Variance.Phantom => proto.Variance.Phantom + case Variance.Covariant => proto.Variance.Covariant case Variance.Contravariant => proto.Variance.Contravariant - case Variance.Invariant => proto.Variance.Invariant + case Variance.Invariant => proto.Variance.Invariant } - + def varianceFromProto(p: proto.Variance): Try[Variance] = p match { - case proto.Variance.Phantom => Success(Variance.Phantom) - case proto.Variance.Covariant => Success(Variance.Covariant) + case proto.Variance.Phantom => Success(Variance.Phantom) + case proto.Variance.Covariant => Success(Variance.Covariant) case proto.Variance.Contravariant => Success(Variance.Contravariant) - case proto.Variance.Invariant => Success(Variance.Invariant) - case proto.Variance.Unrecognized(value) => Failure(new Exception(s"unrecognized value for variance: $value")) + case proto.Variance.Invariant => Success(Variance.Invariant) + case proto.Variance.Unrecognized(value) => + Failure(new Exception(s"unrecognized value for variance: $value")) } def kindToProto(kind: Kind): proto.Kind = @@ -758,14 +962,19 @@ object ProtoConverter { val vp = varianceToProto(v) val ip = kindToProto(i) val op = kindToProto(o) - proto.Kind(proto.Kind.Value.Cons(proto.ConsKind(vp, Some(ip), Some(op)))) + proto.Kind( + proto.Kind.Value.Cons(proto.ConsKind(vp, Some(ip), Some(op))) + ) } def kindFromProto(kp: Option[proto.Kind]): Try[Kind] = kp match { case None | Some(proto.Kind(proto.Kind.Value.Empty, _)) => Failure(new Exception("missing Kind")) - case Some(proto.Kind(proto.Kind.Value.Type(proto.TypeKind(_)), _)) => Success(Kind.Type) - case Some(proto.Kind(proto.Kind.Value.Cons(proto.ConsKind(v, i, o, _)), _)) => + case Some(proto.Kind(proto.Kind.Value.Type(proto.TypeKind(_)), _)) => + Success(Kind.Type) + case Some( + proto.Kind(proto.Kind.Value.Cons(proto.ConsKind(v, i, o, _)), _) + ) => for { variance <- varianceFromProto(v) kindI <- kindFromProto(i) @@ -779,7 +988,11 @@ object ProtoConverter { typeVarBoundToProto(tv._1) .map { tvb => val Kind.Arg(variance, kind) = tv._2 - proto.TypeParam(Some(tvb), varianceToProto(variance), Some(kindToProto(kind))) + proto.TypeParam( + Some(tvb), + varianceToProto(variance), + Some(kindToProto(kind)) + ) } val protoTypeParams: Tab[List[proto.TypeParam]] = @@ -787,30 +1000,36 @@ object ProtoConverter { val constructors: Tab[List[proto.ConstructorFn]] = d.constructors.traverse { cf => - cf.args.traverse { case (b, t) => - typeToProto(t).flatMap { tidx => - getId(b.sourceCodeRepr) - .map { n => - proto.FnParam(n, tidx) + cf.args + .traverse { case (b, t) => + typeToProto(t).flatMap { tidx => + getId(b.sourceCodeRepr) + .map { n => + proto.FnParam(n, tidx) + } + } + } + .flatMap { params => + getId(cf.name.asString) + .map { id => + proto.ConstructorFn(id, params) } } - } - .flatMap { params => - getId(cf.name.asString) - .map { id => - proto.ConstructorFn(id, params) - } - } } (protoTypeParams, constructors) .mapN(proto.DefinedType(Some(tc), _, _)) } - def definedTypeFromProto(pdt: proto.DefinedType): DTab[DefinedType[Kind.Arg]] = { + def definedTypeFromProto( + pdt: proto.DefinedType + ): DTab[DefinedType[Kind.Arg]] = { def paramFromProto(tp: proto.TypeParam): DTab[(Type.Var.Bound, Kind.Arg)] = tp.typeVar match { - case None => ReaderT.liftF(Failure(new Exception(s"expected type variable in $tp"))) + case None => + ReaderT.liftF( + Failure(new Exception(s"expected type variable in $tp")) + ) case Some(tv) => val ka = for { v <- varianceFromProto(tp.variance) @@ -831,10 +1050,12 @@ object ProtoConverter { def consFromProto(c: proto.ConstructorFn): DTab[rankn.ConstructorFn] = lookup(c.name, c.toString) .flatMap { cname => - ReaderT.liftF(toConstructor(cname)) + ReaderT + .liftF(toConstructor(cname)) .flatMap { cname => - //def - c.params.toList.traverse(fnParamFromProto) + // def + c.params.toList + .traverse(fnParamFromProto) .map { fnParams => rankn.ConstructorFn(cname, fnParams) } @@ -842,7 +1063,8 @@ object ProtoConverter { } pdt.typeConst match { - case None => ReaderT.liftF(Failure(new Exception(s"missing typeConst: $pdt"))) + case None => + ReaderT.liftF(Failure(new Exception(s"missing typeConst: $pdt"))) case Some(tc) => for { tconst <- typeConstFromProto(tc) @@ -852,7 +1074,10 @@ object ProtoConverter { } } - def referantToProto[V](allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], r: Referant[V]): Tab[proto.Referant] = + def referantToProto[V]( + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + r: Referant[V] + ): Tab[proto.Referant] = r match { case Referant.Value(t) => typeToProto(t).map { tpeId => @@ -863,16 +1088,26 @@ object ProtoConverter { allDts.get(key) match { case Some((_, idx)) => tabPure( - proto.Referant(proto.Referant.Referant.DefinedType( - proto.DefinedTypeReference( - proto.DefinedTypeReference.Value.LocalDefinedTypePtr(idx + 1)))) + proto.Referant( + proto.Referant.Referant.DefinedType( + proto.DefinedTypeReference( + proto.DefinedTypeReference.Value.LocalDefinedTypePtr( + idx + 1 + ) + ) + ) + ) ) case None => // this is a non-local defined type: typeConstToProto(dt.toTypeConst).map { case tc => - proto.Referant(proto.Referant.Referant.DefinedType( - proto.DefinedTypeReference( - proto.DefinedTypeReference.Value.ImportedDefinedType(tc)))) + proto.Referant( + proto.Referant.Referant.DefinedType( + proto.DefinedTypeReference( + proto.DefinedTypeReference.Value.ImportedDefinedType(tc) + ) + ) + ) } } case Referant.Constructor(dt, cf) => @@ -886,11 +1121,21 @@ object ProtoConverter { proto.Referant.Referant.Constructor( proto.ConstructorReference( proto.ConstructorReference.Value.LocalConstructor( - proto.ConstructorPtr(dtIdx + 1, cIdx + 1)))))) - } - else tabFail(new Exception(s"missing contructor for type $key, ${cf.name}, with local: $dt")) + proto.ConstructorPtr(dtIdx + 1, cIdx + 1) + ) + ) + ) + ) + ) + } else + tabFail( + new Exception( + s"missing contructor for type $key, ${cf.name}, with local: $dt" + ) + ) case None => - (getId(dt.packageName.asString), + ( + getId(dt.packageName.asString), getId(dt.name.ident.sourceCodeRepr), getId(cf.name.sourceCodeRepr) ).mapN { (pid, tid, cid) => @@ -898,12 +1143,19 @@ object ProtoConverter { proto.Referant.Referant.Constructor( proto.ConstructorReference( proto.ConstructorReference.Value.ImportedConstructor( - proto.ImportedConstructor(pid, tid, cid))))) - } + proto.ImportedConstructor(pid, tid, cid) + ) + ) + ) + ) + } } } - def expNameToProto[V](allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], e: ExportedName[Referant[V]]): Tab[proto.ExportedName] = { + def expNameToProto[V]( + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + e: ExportedName[Referant[V]] + ): Tab[proto.ExportedName] = { val protoRef: Tab[proto.Referant] = referantToProto(allDts, e.tag) val exKind: Tab[(Int, proto.ExportKind)] = e match { case ExportedName.Binding(b, _) => @@ -914,10 +1166,15 @@ object ProtoConverter { getId(n.asString).map((_, proto.ExportKind.ConstructorName)) } - (protoRef, exKind).mapN { case (ref, (idx, k)) => proto.ExportedName(k, idx, Some(ref)) } + (protoRef, exKind).mapN { case (ref, (idx, k)) => + proto.ExportedName(k, idx, Some(ref)) + } } - private def packageDeps(strings: Array[String], dt: proto.DefinedType): List[String] = + private def packageDeps( + strings: Array[String], + dt: proto.DefinedType + ): List[String] = dt.typeConst match { case Some(tc) => strings(tc.packageName - 1) :: Nil @@ -928,7 +1185,10 @@ object ProtoConverter { private def ifaceDeps(iface: proto.Interface): List[String] = { val ary = iface.strings.toArray val thisPack = ary(iface.packageName - 1) - iface.definedTypes.toList.flatMap(packageDeps(ary, _).filterNot(_ == thisPack)).distinct.sorted + iface.definedTypes.toList + .flatMap(packageDeps(ary, _).filterNot(_ == thisPack)) + .distinct + .sorted } // what package names does this full package depend on? @@ -943,16 +1203,16 @@ object ProtoConverter { } def interfaceToProto(iface: Package.Interface): Try[proto.Interface] = { - val allDts = DefinedType.listToMap( - iface.exports.flatMap { ex => + val allDts = DefinedType + .listToMap(iface.exports.flatMap { ex => /* * allDts are the locally defined types to this package * so we need to filter those outside this package */ - ex.tag - .definedType + ex.tag.definedType .filter(_.packageName == iface.name) - }).mapWithIndex { (dt, idx) => (dt, idx) } + }) + .mapWithIndex { (dt, idx) => (dt, idx) } val tryProtoDts = allDts .traverse { case (dt, _) => definedTypeToProto(dt) } @@ -965,15 +1225,25 @@ object ProtoConverter { val last = packageId.product(tryProtoDts).product(tryExports) runTab(last).map { case (serstate, ((nm, dts), exps)) => - proto.Interface(serstate.strings.inOrder, serstate.types.inOrder, dts, nm, exps) + proto.Interface( + serstate.strings.inOrder, + serstate.types.inOrder, + dts, + nm, + exps + ) } } - private def referantFromProto(loadDT: Type.Const => Try[DefinedType[Kind.Arg]], ref: proto.Referant): DTab[Referant[Kind.Arg]] = + private def referantFromProto( + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + ref: proto.Referant + ): DTab[Referant[Kind.Arg]] = ref.referant match { case proto.Referant.Referant.Value(t) => lookupType(t, s"invalid type in $ref").map(Referant.Value(_)) - case proto.Referant.Referant.DefinedType(proto.DefinedTypeReference(dt, _)) => + case proto.Referant.Referant + .DefinedType(proto.DefinedTypeReference(dt, _)) => dt match { case proto.DefinedTypeReference.Value.LocalDefinedTypePtr(idx) => lookupDts(idx, s"invalid defined type in $ref") @@ -985,31 +1255,43 @@ object ProtoConverter { case proto.DefinedTypeReference.Value.Empty => ReaderT.liftF(Failure(new Exception(s"empty referant found: $ref"))) } - case proto.Referant.Referant.Constructor(proto.ConstructorReference(consRef, _)) => + case proto.Referant.Referant + .Constructor(proto.ConstructorReference(consRef, _)) => consRef match { - case proto.ConstructorReference.Value.LocalConstructor(proto.ConstructorPtr(dtIdx, cIdx, _)) => + case proto.ConstructorReference.Value + .LocalConstructor(proto.ConstructorPtr(dtIdx, cIdx, _)) => lookupDts(dtIdx, s"invalid defined type in $ref").flatMap { dt => // cIdx is 1 based: val fixedIdx = cIdx - 1 ReaderT.liftF(dt.constructors.get(fixedIdx.toLong) match { case None => - Failure(new Exception(s"invalid constructor index: $cIdx in: $dt")) + Failure( + new Exception(s"invalid constructor index: $cIdx in: $dt") + ) case Some(cf) => Success(Referant.Constructor(dt, cf)) }) } - case proto.ConstructorReference.Value.ImportedConstructor(proto.ImportedConstructor(packId, typeId, consId, _)) => - (lookup(packId, s"imported constructor package in $ref"), + case proto.ConstructorReference.Value.ImportedConstructor( + proto.ImportedConstructor(packId, typeId, consId, _) + ) => + ( + lookup(packId, s"imported constructor package in $ref"), lookup(typeId, s"imported constructor typename in $ref"), - lookup(consId, s"imported constructor name in $ref")) - .tupled + lookup(consId, s"imported constructor name in $ref") + ).tupled .flatMapF { case (p, t, c) => for { tc <- typeConstFromStr(p, t, s"in $ref decoding ($p, $t)") dt <- loadDT(tc) cons <- toConstructor(c) idx = dt.constructors.indexWhere(_.name == cons) - _ <- if (idx < 0) Failure(new Exception(s"invalid constuctor name: $cons for $dt")) else Success(()) + _ <- + if (idx < 0) + Failure( + new Exception(s"invalid constuctor name: $cons for $dt") + ) + else Success(()) } yield Referant.Constructor(dt, dt.constructors(idx)) } case proto.ConstructorReference.Value.Empty => @@ -1020,14 +1302,17 @@ object ProtoConverter { } private def exportedNameFromProto( - loadDT: Type.Const => Try[DefinedType[Kind.Arg]], - en: proto.ExportedName): DTab[ExportedName[Referant[Kind.Arg]]] = { + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + en: proto.ExportedName + ): DTab[ExportedName[Referant[Kind.Arg]]] = { val tryRef: DTab[Referant[Kind.Arg]] = en.referant match { case Some(r) => referantFromProto(loadDT, r) - case None => ReaderT.liftF(Failure(new Exception(s"missing referant in $en"))) + case None => + ReaderT.liftF(Failure(new Exception(s"missing referant in $en"))) } - tryRef.product(lookup(en.name, en.toString)) + tryRef + .product(lookup(en.name, en.toString)) .flatMapF { case (ref, n) => en.exportKind match { case proto.ExportKind.Binding => @@ -1043,7 +1328,7 @@ object ProtoConverter { ExportedName.Constructor(c, ref) } case proto.ExportKind.Unrecognized(idx) => - Failure(new Exception(s"unknown export kind: $idx in $en")) + Failure(new Exception(s"unknown export kind: $idx in $en")) } } } @@ -1054,11 +1339,13 @@ object ProtoConverter { private sealed trait Scoped { def finish[A](dtab: DTab[A]): DTab[A] = this match { - case Scoped.Prep(d, fn) => d.flatMap { b => dtab.local[DecodeState] { ds => fn(ds, b) } } + case Scoped.Prep(d, fn) => + d.flatMap { b => dtab.local[DecodeState] { ds => fn(ds, b) } } } } private object Scoped { - case class Prep[A](dtab: DTab[A], fn: (DecodeState, A) => DecodeState) extends Scoped + case class Prep[A](dtab: DTab[A], fn: (DecodeState, A) => DecodeState) + extends Scoped def apply[A](dtab: DTab[A])(fn: (DecodeState, A) => DecodeState): Scoped = Prep(dtab, fn) @@ -1066,18 +1353,26 @@ object ProtoConverter { s.foldRight(dtab)(_.finish(_)) } - private def interfaceFromProto0(loadDT: Type.Const => Try[DefinedType[Kind.Arg]], protoIface: proto.Interface): Try[Package.Interface] = { + private def interfaceFromProto0( + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + protoIface: proto.Interface + ): Try[Package.Interface] = { val tab: DTab[Package.Interface] = for { packageName <- lookup(protoIface.packageName, protoIface.toString) pn <- ReaderT.liftF(parsePack(packageName, s"interface: $protoIface")) - exports <- protoIface.exports.toList.traverse(exportedNameFromProto(loadDT, _)) + exports <- protoIface.exports.toList.traverse( + exportedNameFromProto(loadDT, _) + ) } yield Package(pn, Nil, exports, ()) // build up the decoding state by decoding the tables in order - Scoped.run( - Scoped(buildTypes(protoIface.types))(_.withTypes(_)), - Scoped(protoIface.definedTypes.toVector.traverse(definedTypeFromProto))(_.withDefinedTypes(_)) + Scoped + .run( + Scoped(buildTypes(protoIface.types))(_.withTypes(_)), + Scoped(protoIface.definedTypes.toVector.traverse(definedTypeFromProto))( + _.withDefinedTypes(_) + ) )(tab) .run(DecodeState.init(protoIface.strings)) } @@ -1085,17 +1380,23 @@ object ProtoConverter { def interfaceFromProto(protoIface: proto.Interface): Try[Package.Interface] = interfacesFromProto(proto.Interfaces(protoIface :: Nil)).map(_.head) - def interfacesToProto[F[_]: Foldable](ps: F[Package.Interface]): Try[proto.Interfaces] = + def interfacesToProto[F[_]: Foldable]( + ps: F[Package.Interface] + ): Try[proto.Interfaces] = ps.toList.traverse(interfaceToProto).map { ifs => // sort so we are deterministic - proto.Interfaces(ifs.sortBy { iface => iface.strings(iface.packageName - 1) }) + proto.Interfaces(ifs.sortBy { iface => + iface.strings(iface.packageName - 1) + }) } def interfacesFromProto(ps: proto.Interfaces): Try[List[Package.Interface]] = // packagesFromProto can handle just interfaces as well packagesFromProto(ps.interfaces, Nil).map(_._1) - def read[A <: GeneratedMessage](path: Path)(implicit gmc: GeneratedMessageCompanion[A]): IO[A] = + def read[A <: GeneratedMessage]( + path: Path + )(implicit gmc: GeneratedMessageCompanion[A]): IO[A] = IO { val f = path.toFile val ios = new BufferedInputStream(new FileInputStream(f)) @@ -1115,16 +1416,21 @@ object ProtoConverter { } } - def readInterfacesAndPackages(ifacePaths: List[Path], packagePaths: List[Path]): IO[(List[Package.Interface], List[Package.Typed[Unit]])] = - (ifacePaths.traverse(read[proto.Interfaces](_)), - packagePaths.traverse(read[proto.Packages](_))) - .tupled + def readInterfacesAndPackages( + ifacePaths: List[Path], + packagePaths: List[Path] + ): IO[(List[Package.Interface], List[Package.Typed[Unit]])] = + ( + ifacePaths.traverse(read[proto.Interfaces](_)), + packagePaths.traverse(read[proto.Packages](_)) + ).tupled .flatMap { case (ifs, packs) => IO.fromTry( packagesFromProto( ifs.flatMap(_.interfaces), packs.flatMap(_.packages) - )) + ) + ) } def readInterfaces(paths: List[Path]): IO[List[Package.Interface]] = @@ -1133,7 +1439,10 @@ object ProtoConverter { def readPackages(paths: List[Path]): IO[List[Package.Typed[Unit]]] = readInterfacesAndPackages(Nil, paths).map(_._2) - def writeInterfaces(interfaces: List[Package.Interface], path: Path): IO[Unit] = + def writeInterfaces( + interfaces: List[Package.Interface], + path: Path + ): IO[Unit] = IO.fromTry(interfacesToProto(interfaces)) .flatMap(write(_, path)) @@ -1142,17 +1451,17 @@ object ProtoConverter { packages .traverse(packageToProto(_)) .map(proto.Packages(_)) - } - .flatMap(write(_, path)) + }.flatMap(write(_, path)) def importedNameToProto( - allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], - in: ImportedName[NonEmptyList[Referant[Kind.Arg]]]): Tab[proto.ImportedName] = { + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + in: ImportedName[NonEmptyList[Referant[Kind.Arg]]] + ): Tab[proto.ImportedName] = { val locName = in match { case ImportedName.OriginalName(_, _) => None - case ImportedName.Renamed(_, l, _) => Some(l) + case ImportedName.Renamed(_, l, _) => Some(l) } for { orig <- getId(in.originalName.sourceCodeRepr) @@ -1162,8 +1471,9 @@ object ProtoConverter { } def importToProto( - allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], - i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]): Tab[proto.Imports] = + allDts: Map[(PackageName, TypeName), (DefinedType[Any], Int)], + i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] + ): Tab[proto.Imports] = for { nm <- getId(i.pack.name.asString) imps <- i.items.toList.traverse(importedNameToProto(allDts, _)) @@ -1172,7 +1482,9 @@ object ProtoConverter { def letToProto(l: (Bindable, RecursionKind, TypedExpr[Any])): Tab[proto.Let] = for { nm <- getId(l._1.sourceCodeRepr) - rec = if (l._2.isRecursive) proto.RecursionKind.IsRec else proto.RecursionKind.NotRec + rec = + if (l._2.isRecursive) proto.RecursionKind.IsRec + else proto.RecursionKind.NotRec tex <- typedExprToProto(l._3) } yield proto.Let(nm, rec, tex) @@ -1185,7 +1497,8 @@ object ProtoConverter { def packageToProto[A](cpack: Package.Typed[A]): Try[proto.Package] = { // the Int is in index in the list of definedTypes: - val allDts: SortedMap[(PackageName, TypeName), (DefinedType[Kind.Arg], Int)] = + val allDts + : SortedMap[(PackageName, TypeName), (DefinedType[Kind.Arg], Int)] = cpack.program.types.definedTypes.mapWithIndex { (dt, idx) => (dt, idx) } val dtVect: Vector[DefinedType[Kind.Arg]] = allDts.values.iterator.map(_._1).toVector @@ -1196,20 +1509,23 @@ object ProtoConverter { exps <- cpack.exports.traverse(expNameToProto(allDts, _)) prog = cpack.program lets <- prog.lets.traverse(letToProto) - exdefs <- prog.externalDefs.traverse { nm => extDefToProto(nm, prog.types.getValue(cpack.name, nm)) } + exdefs <- prog.externalDefs.traverse { nm => + extDefToProto(nm, prog.types.getValue(cpack.name, nm)) + } dts <- dtVect.traverse(definedTypeToProto) } yield { (ss: SerState) => - proto.Package( - strings = ss.strings.inOrder, - types = ss.types.inOrder, - definedTypes = dts, - patterns = ss.patterns.inOrder, - expressions = ss.expressions.inOrder, - packageName = nmId, - imports = imps, - exports = exps, - lets = lets, - externalDefs = exdefs) + proto.Package( + strings = ss.strings.inOrder, + types = ss.types.inOrder, + definedTypes = dts, + patterns = ss.patterns.inOrder, + expressions = ss.expressions.inOrder, + packageName = nmId, + imports = imps, + exports = exps, + lets = lets, + externalDefs = exdefs + ) } runTab(tab).map { case (ss, fn) => fn(ss) } @@ -1219,7 +1535,8 @@ object ProtoConverter { Success(Identifier.Name("$anon")) def toBindable(str: String): Try[Bindable] = - if (str == "$anon") anonBind // used in Expr to create some lambdas with pattern match + if (str == "$anon") + anonBind // used in Expr to create some lambdas with pattern match else Try(Identifier.unsafeParse(Identifier.bindableParser, str)) def toIdent(str: String): Try[Identifier] = @@ -1236,19 +1553,22 @@ object ProtoConverter { lookup(idx, context).flatMapF(toIdent) def importedNameFromProto( - loadDT: Type.Const => Try[DefinedType[Kind.Arg]], - iname: proto.ImportedName): DTab[ImportedName[NonEmptyList[Referant[Kind.Arg]]]] = { + loadDT: Type.Const => Try[DefinedType[Kind.Arg]], + iname: proto.ImportedName + ): DTab[ImportedName[NonEmptyList[Referant[Kind.Arg]]]] = { def build[A](orig: Identifier, ref: A): DTab[ImportedName[A]] = if (iname.localName == 0) { ReaderT.pure(ImportedName.OriginalName(originalName = orig, ref)) - } - else { + } else { lookupIdentifier(iname.localName, iname.toString) .map(ImportedName.Renamed(originalName = orig, _, ref)) } NonEmptyList.fromList(iname.referant.toList) match { - case None => ReaderT.liftF(Failure(new Exception(s"expected at least one imported name: $iname"))) + case None => + ReaderT.liftF( + Failure(new Exception(s"expected at least one imported name: $iname")) + ) case Some(refs) => for { orig <- lookupIdentifier(iname.originalName, iname.toString) @@ -1258,11 +1578,16 @@ object ProtoConverter { } } - def importsFromProto(imp: proto.Imports, - lookupIface: PackageName => Try[Package.Interface], - loadDT: Type.Const => Try[DefinedType[Kind.Arg]]): DTab[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] = + def importsFromProto( + imp: proto.Imports, + lookupIface: PackageName => Try[Package.Interface], + loadDT: Type.Const => Try[DefinedType[Kind.Arg]] + ): DTab[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] = NonEmptyList.fromList(imp.names.toList) match { - case None => ReaderT.liftF(Failure(new Exception(s"expected non-empty import names in: $imp"))) + case None => + ReaderT.liftF( + Failure(new Exception(s"expected non-empty import names in: $imp")) + ) case Some(nei) => for { pnameStr <- lookup(imp.packageName, imp.toString) @@ -1272,20 +1597,30 @@ object ProtoConverter { } yield Import(iface, inames) } - def letsFromProto(let: proto.Let): DTab[(Bindable, RecursionKind, TypedExpr[Unit])] = - (lookupBindable(let.name, let.toString), - ReaderT.liftF(recursionKindFromProto(let.rec, let.toString)): DTab[RecursionKind], - lookupExpr(let.expr, let.toString)).tupled + def letsFromProto( + let: proto.Let + ): DTab[(Bindable, RecursionKind, TypedExpr[Unit])] = + ( + lookupBindable(let.name, let.toString), + ReaderT.liftF(recursionKindFromProto(let.rec, let.toString)): DTab[ + RecursionKind + ], + lookupExpr(let.expr, let.toString) + ).tupled def externalDefsFromProto(ed: proto.ExternalDef): DTab[(Bindable, Type)] = - (lookupBindable(ed.name, ed.toString), - lookupType(ed.typeOf, ed.toString)).tupled + ( + lookupBindable(ed.name, ed.toString), + lookupType(ed.typeOf, ed.toString) + ).tupled def buildProgram( - pack: PackageName, - lets: List[(Bindable, RecursionKind, TypedExpr[Unit])], - exts: List[(Bindable, Type)]): DTab[Program[TypeEnv[Kind.Arg], TypedExpr[Unit], Unit]] = - ReaderT.ask[Try, DecodeState] + pack: PackageName, + lets: List[(Bindable, RecursionKind, TypedExpr[Unit])], + exts: List[(Bindable, Type)] + ): DTab[Program[TypeEnv[Kind.Arg], TypedExpr[Unit], Unit]] = + ReaderT + .ask[Try, DecodeState] .map { ds => // this adds all the types and contructors // from the given defined types @@ -1299,21 +1634,24 @@ object ProtoConverter { } def packagesFromProto( - ifaces: Iterable[proto.Interface], - packs: Iterable[proto.Package]): Try[(List[Package.Interface], List[Package.Typed[Unit]])] = { + ifaces: Iterable[proto.Interface], + packs: Iterable[proto.Package] + ): Try[(List[Package.Interface], List[Package.Typed[Unit]])] = { type Node = Either[proto.Interface, proto.Package] def iname(p: proto.Interface): String = - p.strings.lift(p.packageName - 1) + p.strings + .lift(p.packageName - 1) .getOrElse("_unknown_" + p.packageName.toString) def pname(p: proto.Package): String = - p.strings.lift(p.packageName - 1) + p.strings + .lift(p.packageName - 1) .getOrElse("_unknown_" + p.packageName.toString) def nodeName(n: Node): String = n match { - case Left(i) => iname(i) + case Left(i) => iname(i) case Right(p) => pname(p) } @@ -1323,17 +1661,18 @@ object ProtoConverter { (l, r) match { case (Left(_), Right(_)) => -1 case (Right(_), Left(_)) => 1 - case (nl, nr) => nodeName(nl).compareTo(nodeName(nr)) + case (nl, nr) => nodeName(nl).compareTo(nodeName(nr)) } } - val nodes: List[Node] = ifaces.map(Left(_)).toList ::: packs.map(Right(_)).toList + val nodes: List[Node] = + ifaces.map(Left(_)).toList ::: packs.map(Right(_)).toList val nodeMap: Map[String, List[Node]] = nodes.groupBy(nodeName) def getNodes(n: String, parent: Node): List[Node] = nodeMap.get(n) match { - case Some(ns) => ns + case Some(ns) => ns case None if n == PackageName.PredefName.asString => // we can load the predef below Nil @@ -1345,132 +1684,170 @@ object ProtoConverter { // so, the unsafe calls inside are checked before we call def dependsOn(n: Node): List[Node] = n match { - case Left(i) => ifaceDeps(i).flatMap { dep => getNodes(dep, n) } + case Left(i) => ifaceDeps(i).flatMap { dep => getNodes(dep, n) } case Right(p) => packageDeps(p).flatMap { dep => getNodes(dep, n) } } val dupNames: List[String] = - nodeMap - .iterator + nodeMap.iterator .filter { case (_, vs) => vs.lengthCompare(1) > 0 } .map(_._1) .toList .sorted Try(graph.Toposort.sort(nodes)(dependsOn)).flatMap { sorted => + if (dupNames.nonEmpty) { + Failure( + new Exception("duplicate package names: " + dupNames.mkString(", ")) + ) + } else if (sorted.isFailure) { + val loopStr = + sorted.loopNodes + .map { + case Left(i) => "interface: " + iname(i) + case Right(p) => "compiled: " + pname(p) + } + .mkString(", ") + Failure(new Exception(s"circular dependencies in packages: $loopStr")) + } else { + def makeLoadDT( + load: String => Try[Either[ + (Package.Interface, TypeEnv[Kind.Arg]), + Package.Typed[Unit] + ]] + ): Type.Const => Try[DefinedType[Kind.Arg]] = { + case tc @ Type.Const.Defined(p, _) => + val res = load(p.asString).map { + case Left((_, dt)) => + dt.toDefinedType(tc) + case Right(comp) => + comp.program.types.toDefinedType(tc) + } - if (dupNames.nonEmpty) { - Failure(new Exception("duplicate package names: " + dupNames.mkString(", "))) - } - else if (sorted.isFailure) { - val loopStr = - sorted - .loopNodes - .map { - case Left(i) => "interface: " + iname(i) - case Right(p) => "compiled: " + pname(p) - } - .mkString(", ") - Failure(new Exception(s"circular dependencies in packages: $loopStr")) - } - else { - def makeLoadDT( - load: String => Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]] - ): Type.Const => Try[DefinedType[Kind.Arg]] = { case tc@Type.Const.Defined(p, _) => - val res = load(p.asString).map { - case Left((_, dt)) => - dt.toDefinedType(tc) - case Right(comp) => - comp.program.types.toDefinedType(tc) + res.flatMap { + case None => + Failure(new Exception(s"unknown type $tc not present")) + case Some(dt) => Success(dt) + } } - res.flatMap { - case None => Failure(new Exception(s"unknown type $tc not present")) - case Some(dt) => Success(dt) - } - } + /* + * We know we have a dag now, so we can just go through + * loading them. + * + * We will need a list of these an memoize loading them all + */ - /* - * We know we have a dag now, so we can just go through - * loading them. - * - * We will need a list of these an memoize loading them all - */ - - def packFromProtoUncached( - pack: proto.Package, - load: String => Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]] - ): Try[Package.Typed[Unit]] = { - val loadIface: PackageName => Try[Package.Interface] = { p => - load(p.asString).map { - case Left((iface, _)) => iface - case Right(pack) => Package.interfaceOf(pack) + def packFromProtoUncached( + pack: proto.Package, + load: String => Try[Either[ + (Package.Interface, TypeEnv[Kind.Arg]), + Package.Typed[Unit] + ]] + ): Try[Package.Typed[Unit]] = { + val loadIface: PackageName => Try[Package.Interface] = { p => + load(p.asString).map { + case Left((iface, _)) => iface + case Right(pack) => Package.interfaceOf(pack) + } } - } - val loadDT = makeLoadDT(load) - - val tab: DTab[Package.Typed[Unit]] = - for { - packageNameStr <- lookup(pack.packageName, pack.toString) - packageName <- ReaderT.liftF(parsePack(packageNameStr, pack.toString)) - imps <- pack.imports.toList.traverse(importsFromProto(_, loadIface, loadDT)) - exps <- pack.exports.toList.traverse(exportedNameFromProto(loadDT, _)) - lets <- pack.lets.toList.traverse(letsFromProto) - eds <- pack.externalDefs.toList.traverse(externalDefsFromProto) - prog <- buildProgram(packageName, lets, eds) - } yield Package(packageName, imps, exps, prog) - - // build up the decoding state by decoding the tables in order - val tab1 = Scoped.run( - Scoped(buildTypes(pack.types))(_.withTypes(_)), - Scoped(pack.definedTypes.toVector.traverse(definedTypeFromProto))(_.withDefinedTypes(_)), - Scoped(buildPatterns(pack.patterns))(_.withPatterns(_)), - Scoped(buildExprs(pack.expressions))(_.withExprs(_)) + val loadDT = makeLoadDT(load) + + val tab: DTab[Package.Typed[Unit]] = + for { + packageNameStr <- lookup(pack.packageName, pack.toString) + packageName <- ReaderT.liftF( + parsePack(packageNameStr, pack.toString) + ) + imps <- pack.imports.toList.traverse( + importsFromProto(_, loadIface, loadDT) + ) + exps <- pack.exports.toList.traverse( + exportedNameFromProto(loadDT, _) + ) + lets <- pack.lets.toList.traverse(letsFromProto) + eds <- pack.externalDefs.toList.traverse(externalDefsFromProto) + prog <- buildProgram(packageName, lets, eds) + } yield Package(packageName, imps, exps, prog) + + // build up the decoding state by decoding the tables in order + val tab1 = Scoped.run( + Scoped(buildTypes(pack.types))(_.withTypes(_)), + Scoped(pack.definedTypes.toVector.traverse(definedTypeFromProto))( + _.withDefinedTypes(_) + ), + Scoped(buildPatterns(pack.patterns))(_.withPatterns(_)), + Scoped(buildExprs(pack.expressions))(_.withExprs(_)) )(tab) - tab1.run(DecodeState.init(pack.strings)) - } + tab1.run(DecodeState.init(pack.strings)) + } - val predefIface = { - val iface = Package.interfaceOf(PackageMap.predefCompiled) - (iface, ExportedName.typeEnvFromExports(iface.name, iface.exports)) - } + val predefIface = { + val iface = Package.interfaceOf(PackageMap.predefCompiled) + (iface, ExportedName.typeEnvFromExports(iface.name, iface.exports)) + } - val load: String => Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]] = - Memoize.memoizeDagHashed[String, Try[Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]]]] { (pack, rec) => - nodeMap.get(pack) match { - case Some(Left(iface) :: Nil) => - interfaceFromProto0(makeLoadDT(rec), iface) - .map { iface => Left((iface, ExportedName.typeEnvFromExports(iface.name, iface.exports))) } - case Some(Right(p) :: Nil) => - packFromProtoUncached(p, rec) - .map(Right(_)) - case None if pack == PackageName.PredefName.asString => - // if we haven't replaced explicitly, use the built in predef - Success(Left(predefIface)) - case found => - Failure(new Exception(s"missing interface or compiled: $pack, found: $found")) + val load: String => Try[ + Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]] + ] = + Memoize.memoizeDagHashed[String, Try[ + Either[(Package.Interface, TypeEnv[Kind.Arg]), Package.Typed[Unit]] + ]] { (pack, rec) => + nodeMap.get(pack) match { + case Some(Left(iface) :: Nil) => + interfaceFromProto0(makeLoadDT(rec), iface) + .map { iface => + Left( + ( + iface, + ExportedName.typeEnvFromExports( + iface.name, + iface.exports + ) + ) + ) + } + case Some(Right(p) :: Nil) => + packFromProtoUncached(p, rec) + .map(Right(_)) + case None if pack == PackageName.PredefName.asString => + // if we haven't replaced explicitly, use the built in predef + Success(Left(predefIface)) + case found => + Failure( + new Exception( + s"missing interface or compiled: $pack, found: $found" + ) + ) + } } - } - val deserPack: proto.Package => Try[Package.Typed[Unit]] = { p => - load(pname(p)).flatMap { - case Left((iface, _)) => Failure(new Exception(s"expected compiled for ${iface.name.asString}, found interface")) - case Right(pack) => Success(pack) + val deserPack: proto.Package => Try[Package.Typed[Unit]] = { p => + load(pname(p)).flatMap { + case Left((iface, _)) => + Failure( + new Exception( + s"expected compiled for ${iface.name.asString}, found interface" + ) + ) + case Right(pack) => Success(pack) + } } - } - val deserIface: proto.Interface => Try[Package.Interface] = { p => - load(iname(p)).map { - case Left((iface, _)) => iface - case Right(pack) => Package.interfaceOf(pack) + val deserIface: proto.Interface => Try[Package.Interface] = { p => + load(iname(p)).map { + case Left((iface, _)) => iface + case Right(pack) => Package.interfaceOf(pack) + } } - } - // use the cached versions down here - (ifaces.toList.traverse(deserIface), - packs.toList.traverse(deserPack)).tupled - } + // use the cached versions down here + ( + ifaces.toList.traverse(deserIface), + packs.toList.traverse(deserPack) + ).tupled + } } } } diff --git a/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala b/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala index 05682a633..616fa5558 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/JsonTest.scala @@ -1,6 +1,9 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.typelevel.jawn.ast.{JValue, JParser} import GenJson._ @@ -14,11 +17,11 @@ class JsonJawnTest extends AnyFunSuite { def matches(j1: Json, j2: JValue): Unit = { import Json._ j1 match { - case JString(str) => assert(j2.asString == str); () + case JString(str) => assert(j2.asString == str); () case JNumberStr(nstr) => assert(BigDecimal(nstr) == j2.asBigDecimal); () - case JNull => assert(j2.isNull); () - case JBool.True => assert(j2.asBoolean); () - case JBool.False => assert(!j2.asBoolean); () + case JNull => assert(j2.isNull); () + case JBool.True => assert(j2.asBoolean); () + case JBool.False => assert(!j2.asBoolean); () case JArray(js) => js.zipWithIndex.foreach { case (j, idx) => matches(j, j2.get(idx)) diff --git a/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala b/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala index 4027700cf..06f0b04e4 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/PathModuleTest.scala @@ -26,11 +26,32 @@ class PathModuleTest extends AnyFunSuite { def pn(roots: List[String], file: String): Option[PackageName] = PathModule.pathPackage(roots.map(Paths.get(_)), Paths.get(file)) - assert(pn(List("/root0", "/root1"), "/root0/Bar.bosatsu") == Some(PackageName(NonEmptyList.of("Bar")))) - assert(pn(List("/root0", "/root1"), "/root1/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) - assert(pn(List("/root0", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) - assert(pn(List("/root0/", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) - assert(pn(List("/root0/ext", "/root0/Bar"), "/root0/ext/Bar/Baz.bosatsu") == Some(PackageName(NonEmptyList.of("Bar", "Baz")))) + assert( + pn(List("/root0", "/root1"), "/root0/Bar.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar")) + ) + ) + assert( + pn(List("/root0", "/root1"), "/root1/Bar/Baz.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar", "Baz")) + ) + ) + assert( + pn(List("/root0", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar", "Baz")) + ) + ) + assert( + pn(List("/root0/", "/root0/Bar"), "/root0/Bar/Baz.bosatsu") == Some( + PackageName(NonEmptyList.of("Bar", "Baz")) + ) + ) + assert( + pn( + List("/root0/ext", "/root0/Bar"), + "/root0/ext/Bar/Baz.bosatsu" + ) == Some(PackageName(NonEmptyList.of("Bar", "Baz"))) + ) } test("no roots means no Package") { @@ -50,7 +71,9 @@ class PathModuleTest extends AnyFunSuite { if (rest.toString != "" && root.toString != "") { val path = root.resolve(rest) val pack = - PackageName.parse(rest.asScala.map(_.toString.toLowerCase.capitalize).mkString("/")) + PackageName.parse( + rest.asScala.map(_.toString.toLowerCase.capitalize).mkString("/") + ) assert(PathModule.pathPackage(root :: otherRoots, path) == pack) } } @@ -60,7 +83,8 @@ class PathModuleTest extends AnyFunSuite { val regressions: List[(Path, List[Path], Path)] = List( (Paths.get(""), Nil, Paths.get("/foo/bar")), - (Paths.get(""), List(Paths.get("")), Paths.get("/foo/bar"))) + (Paths.get(""), List(Paths.get("")), Paths.get("/foo/bar")) + ) regressions.foreach { case (r, o, e) => law(r, o, e) } } @@ -70,7 +94,9 @@ class PathModuleTest extends AnyFunSuite { val roots = (r0 :: roots0).filterNot(_.toString == "") val pack = PathModule.pathPackage(roots, file) - val noPrefix = !roots.exists { r => file.asScala.toList.startsWith(r.asScala.toList) } + val noPrefix = !roots.exists { r => + file.asScala.toList.startsWith(r.asScala.toList) + } if (noPrefix) assert(pack == None) } @@ -80,30 +106,43 @@ class PathModuleTest extends AnyFunSuite { PathModule.run(args.toList) match { case Left(h) => fail(s"got help: $h on command: ${args.toList}") case Right(io) => - io.attempt.flatMap { - case Right(out) => - PathModule.reportOutput(out).as(out) - case Left(err) => - PathModule.reportException(err) *> IO.raiseError(err) - } - .unsafeRunSync() + io.attempt + .flatMap { + case Right(out) => + PathModule.reportOutput(out).as(out) + case Left(err) => + PathModule.reportException(err) *> IO.raiseError(err) + } + .unsafeRunSync() } test("test direct run of a file") { - val out = run("test --input test_workspace/List.bosatsu --input test_workspace/Nat.bosatsu --input test_workspace/Bool.bosatsu --test_file test_workspace/Queue.bosatsu".split("\\s+").toSeq: _*) + val out = run( + "test --input test_workspace/List.bosatsu --input test_workspace/Nat.bosatsu --input test_workspace/Bool.bosatsu --test_file test_workspace/Queue.bosatsu" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TestOutput(results, _) => - val res = results.collect { case (pn, Some(t)) if pn.asString == "Queue" => t.value } + val res = results.collect { + case (pn, Some(t)) if pn.asString == "Queue" => t.value + } assert(res.length == 1) case other => fail(s"expected test output: $other") } } test("test search run of a file") { - val out = run("test --package_root test_workspace --search --test_file test_workspace/Bar.bosatsu".split("\\s+").toSeq: _*) + val out = run( + "test --package_root test_workspace --search --test_file test_workspace/Bar.bosatsu" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TestOutput(results, _) => - val res = results.collect { case (pn, Some(t)) if pn.asString == "Bar" => t.value } + val res = results.collect { + case (pn, Some(t)) if pn.asString == "Bar" => t.value + } assert(res.length == 1) assert(res.head.assertions == 1) assert(res.head.failureCount == 0) @@ -112,7 +151,11 @@ class PathModuleTest extends AnyFunSuite { } test("test python transpile on the entire test_workspace") { - val out = run("transpile --input_dir test_workspace/ --outdir pyout --lang python --package_root test_workspace".split("\\s+").toSeq: _*) + val out = run( + "transpile --input_dir test_workspace/ --outdir pyout --lang python --package_root test_workspace" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TranspileOut(_, _) => assert(true) @@ -122,18 +165,29 @@ class PathModuleTest extends AnyFunSuite { test("test search with json write") { - val out = run("json write --package_root test_workspace --search --main_file test_workspace/Bar.bosatsu".split("\\s+").toSeq: _*) + val out = run( + "json write --package_root test_workspace --search --main_file test_workspace/Bar.bosatsu" + .split("\\s+") + .toSeq: _* + ) out match { - case PathModule.Output.JsonOutput(j@Json.JObject(_), _) => - assert(j.toMap == Map("value" -> Json.JBool(true), "message" -> Json.JString("got the right string"))) + case PathModule.Output.JsonOutput(j @ Json.JObject(_), _) => + assert( + j.toMap == Map( + "value" -> Json.JBool(true), + "message" -> Json.JString("got the right string") + ) + ) assert(j.items.length == 2) case other => fail(s"expected json object output: $other") } } test("test search json apply") { - val cmd = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" - .split("\\s+").toList :+ "[2, 4]" + val cmd = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" + .split("\\s+") + .toList :+ "[2, 4]" run(cmd: _*) match { case PathModule.Output.JsonOutput(Json.JNumberStr("8"), _) => succeed @@ -142,11 +196,17 @@ class PathModuleTest extends AnyFunSuite { } test("test search json traverse") { - val cmd = "json traverse --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" - .split("\\s+").toList :+ "[[2, 4], [3, 5]]" + val cmd = + "json traverse --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" + .split("\\s+") + .toList :+ "[[2, 4], [3, 5]]" run(cmd: _*) match { - case PathModule.Output.JsonOutput(Json.JArray(Vector(Json.JNumberStr("8"), Json.JNumberStr("15"))), _) => succeed + case PathModule.Output.JsonOutput( + Json.JArray(Vector(Json.JNumberStr("8"), Json.JNumberStr("15"))), + _ + ) => + succeed case other => fail(s"expected json object output: $other") } } @@ -163,7 +223,8 @@ class PathModuleTest extends AnyFunSuite { } // ill-typed json fails - val cmd = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" + val cmd = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult --json_string" fails(cmd, "[\"2\", 4]") fails(cmd, "[2, \"4\"]") // wrong arity @@ -171,42 +232,61 @@ class PathModuleTest extends AnyFunSuite { fails(cmd, "[2]") fails(cmd, "[]") // unknown command fails - val badName = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::foooooo --json_string 23" + val badName = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::foooooo --json_string 23" fails(badName) - val badPack = "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/DoesNotExist --json_string 23" + val badPack = + "json apply --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/DoesNotExist --json_string 23" fails(badPack) // bad json fails fails(cmd, "[\"2\", foo, bla]") fails(cmd, "[42, 31] and some junk") // exercise unsupported, we cannot write mult, it is a function - fails("json write --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult") + fails( + "json write --input_dir test_workspace/ --package_root test_workspace/ --main Bosatsu/Nat::mult" + ) // a bad main name triggers help - PathModule.run("json write --input_dir test_workspace --main Bo//".split(' ').toList) match { - case Left(_) => succeed + PathModule.run( + "json write --input_dir test_workspace --main Bo//".split(' ').toList + ) match { + case Left(_) => succeed case Right(_) => fail() } - PathModule.run("json write --input_dir test_workspace --main Bo:::boop".split(' ').toList) match { - case Left(_) => succeed + PathModule.run( + "json write --input_dir test_workspace --main Bo:::boop".split(' ').toList + ) match { + case Left(_) => succeed case Right(_) => fail() } } test("test running all test in test_workspace") { - val out = run("test --package_root test_workspace --input_dir test_workspace".split("\\s+").toSeq: _*) + val out = run( + "test --package_root test_workspace --input_dir test_workspace" + .split("\\s+") + .toSeq: _* + ) out match { case PathModule.Output.TestOutput(res, _) => val noTests = res.collect { case (pn, None) => pn }.toList assert(noTests == Nil) - val failures = res.collect { case (pn, Some(t)) if t.value.failureCount > 0 => pn } + val failures = res.collect { + case (pn, Some(t)) if t.value.failureCount > 0 => pn + } assert(failures == Nil) case other => fail(s"expected test output: $other") } } test("evaluation by name with shadowing") { - run("json write --package_root test_workspace --input test_workspace/Foo.bosatsu --main Foo::x".split("\\s+").toSeq: _*) match { - case PathModule.Output.JsonOutput(Json.JString("this is Foo"), _) => succeed + run( + "json write --package_root test_workspace --input test_workspace/Foo.bosatsu --main Foo::x" + .split("\\s+") + .toSeq: _* + ) match { + case PathModule.Output.JsonOutput(Json.JString("this is Foo"), _) => + succeed case other => fail(s"unexpeced: $other") } } diff --git a/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala b/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala index f0e199a0f..cca7a98b2 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/TestProtoType.scala @@ -5,7 +5,10 @@ import cats.Eq import cats.effect.{IO, Resource} import org.bykn.bosatsu.rankn.Type import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import scala.util.{Failure, Success, Try} import cats.implicits._ @@ -17,9 +20,9 @@ import org.scalatest.funsuite.AnyFunSuite class TestProtoType extends AnyFunSuite with ParTest { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 100) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) def law[A: Eq, B](a: A, fn: A => Try[B], gn: B => Try[A]) = { val maybeProto = fn(a) @@ -39,12 +42,16 @@ class TestProtoType extends AnyFunSuite with ParTest { .zip(orig.toString) .zipWithIndex .dropWhile { case ((a, b), _) => a == b } - .headOption.map(_._2) + .headOption + .map(_._2) .getOrElse(0) val context = 100 - assert(Eq[A].eqv(a, orig), s"${a.toString.drop(diffIdx - context/2).take(context)} != ${orig.toString.drop(diffIdx - context/2).take(context)}") - //assert(Eq[A].eqv(a, orig), s"$a\n\n!=\n\n$orig") + assert( + Eq[A].eqv(a, orig), + s"${a.toString.drop(diffIdx - context / 2).take(context)} != ${orig.toString.drop(diffIdx - context / 2).take(context)}" + ) + // assert(Eq[A].eqv(a, orig), s"$a\n\n!=\n\n$orig") } def testWithTempFile(fn: Path => IO[Unit]): Unit = { @@ -63,7 +70,9 @@ class TestProtoType extends AnyFunSuite with ParTest { tempRes.use(fn).unsafeRunSync() } - def tabLaw[A: Eq, B](f: A => ProtoConverter.Tab[B])(g: (ProtoConverter.SerState, B) => ProtoConverter.DTab[A]) = { (a: A) => + def tabLaw[A: Eq, B]( + f: A => ProtoConverter.Tab[B] + )(g: (ProtoConverter.SerState, B) => ProtoConverter.DTab[A]) = { (a: A) => f(a).run(ProtoConverter.SerState.empty) match { case Success((ss, b)) => val ds = ProtoConverter.DecodeState.init(ss.strings.inOrder) @@ -87,10 +96,14 @@ class TestProtoType extends AnyFunSuite with ParTest { } test("we can roundtrip patterns through proto") { - val testFn = tabLaw(ProtoConverter.patternToProto(_: Pattern[(PackageName, Constructor), Type])) { (ss, idx) => + val testFn = tabLaw( + ProtoConverter.patternToProto( + _: Pattern[(PackageName, Constructor), Type] + ) + ) { (ss, idx) => for { tps <- ProtoConverter.buildTypes(ss.types.inOrder) - pats = ProtoConverter.buildPatterns(ss.patterns.inOrder).map(_(idx - 1)) + pats = ProtoConverter.buildPatterns(ss.patterns.inOrder).map(_(idx - 1)) res <- pats.local[ProtoConverter.DecodeState](_.withTypes(tps)) } yield res }(Eq.fromUniversalEquals) @@ -99,22 +112,33 @@ class TestProtoType extends AnyFunSuite with ParTest { } test("we can roundtrip TypedExpr through proto") { - val testFn = tabLaw(ProtoConverter.typedExprToProto(_: TypedExpr[Unit])) { (ss, idx) => - for { - tps <- ProtoConverter.buildTypes(ss.types.inOrder) - pats = ProtoConverter.buildPatterns(ss.patterns.inOrder) - patTab <- pats.local[ProtoConverter.DecodeState](_.withTypes(tps)) - expr = ProtoConverter.buildExprs(ss.expressions.inOrder).map(_(idx - 1)) - res <- expr.local[ProtoConverter.DecodeState](_.withTypes(tps).withPatterns(patTab)) - } yield res + val testFn = tabLaw(ProtoConverter.typedExprToProto(_: TypedExpr[Unit])) { + (ss, idx) => + for { + tps <- ProtoConverter.buildTypes(ss.types.inOrder) + pats = ProtoConverter.buildPatterns(ss.patterns.inOrder) + patTab <- pats.local[ProtoConverter.DecodeState](_.withTypes(tps)) + expr = ProtoConverter + .buildExprs(ss.expressions.inOrder) + .map(_(idx - 1)) + res <- expr.local[ProtoConverter.DecodeState]( + _.withTypes(tps).withPatterns(patTab) + ) + } yield res }(Eq.fromUniversalEquals) - forAll(Generators.genTypedExpr(Gen.const(()), 4, rankn.NTypeGen.genDepth03))(testFn) + forAll( + Generators.genTypedExpr(Gen.const(()), 4, rankn.NTypeGen.genDepth03) + )(testFn) } test("we can roundtrip interface through proto") { forAll(Generators.interfaceGen) { iface => - law(iface, ProtoConverter.interfaceToProto _, ProtoConverter.interfaceFromProto _)(Eq.fromUniversalEquals) + law( + iface, + ProtoConverter.interfaceToProto _, + ProtoConverter.interfaceFromProto _ + )(Eq.fromUniversalEquals) } } @@ -127,49 +151,71 @@ class TestProtoType extends AnyFunSuite with ParTest { } test("we can roundtrip interfaces through proto") { - forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { ifaces => - law(ifaces, ProtoConverter.interfacesToProto[List] _, ProtoConverter.interfacesFromProto _)(sortedEq) + forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { + ifaces => + law( + ifaces, + ProtoConverter.interfacesToProto[List] _, + ProtoConverter.interfacesFromProto _ + )(sortedEq) } } test("we can roundtrip interfaces from full packages through proto") { forAll(Generators.genPackage(Gen.const(()), 10)) { packMap => - val ifaces = packMap.iterator.map { case (_, p) => Package.interfaceOf(p) }.toList - law(ifaces, ProtoConverter.interfacesToProto[List] _, ProtoConverter.interfacesFromProto _)(sortedEq) + val ifaces = packMap.iterator.map { case (_, p) => + Package.interfaceOf(p) + }.toList + law( + ifaces, + ProtoConverter.interfacesToProto[List] _, + ProtoConverter.interfacesFromProto _ + )(sortedEq) } } test("we can roundtrip interfaces through file") { - forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { ifaces => - testWithTempFile { path => - for { - _ <- ProtoConverter.writeInterfaces(ifaces, path) - ifaces1 <- ProtoConverter.readInterfaces(path :: Nil) - _ = assert(sortedEq.eqv(ifaces, ifaces1)) - } yield () - } + forAll(Generators.smallDistinctByList(Generators.interfaceGen)(_.name)) { + ifaces => + testWithTempFile { path => + for { + _ <- ProtoConverter.writeInterfaces(ifaces, path) + ifaces1 <- ProtoConverter.readInterfaces(path :: Nil) + _ = assert(sortedEq.eqv(ifaces, ifaces1)) + } yield () + } } } test("test some hand written packages") { - def ser(p: List[Package.Typed[Unit]]): Try[List[proto.Package]] = - p.traverse(ProtoConverter.packageToProto) - def deser(ps: List[proto.Package]): Try[List[Package.Typed[Unit]]] = - ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => p.sortBy(_.name) } + def ser(p: List[Package.Typed[Unit]]): Try[List[proto.Package]] = + p.traverse(ProtoConverter.packageToProto) + def deser(ps: List[proto.Package]): Try[List[Package.Typed[Unit]]] = + ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => + p.sortBy(_.name) + } val tf = Package.typedFunctor - TestUtils.testInferred(List( -"""package Foo + TestUtils.testInferred( + List( + """package Foo export bar bar = 1 """ - ), "Foo", { (packs, _) => - law(packs.toMap.values.toList.sortBy(_.name).map { pt => Package.setProgramFrom(tf.void(pt), ()) }, - ser _, - deser _)(Eq.fromUniversalEquals) - }) + ), + "Foo", + { (packs, _) => + law( + packs.toMap.values.toList.sortBy(_.name).map { pt => + Package.setProgramFrom(tf.void(pt), ()) + }, + ser _, + deser _ + )(Eq.fromUniversalEquals) + } + ) } test("we can roundtrip packages through proto") { @@ -177,7 +223,9 @@ bar = 1 def ser(p: List[Package.Typed[Unit]]): Try[List[proto.Package]] = p.traverse(ProtoConverter.packageToProto) def deser(ps: List[proto.Package]): Try[List[Package.Typed[Unit]]] = - ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => p.sortBy(_.name) } + ProtoConverter.packagesFromProto(Nil, ps).map { case (_, p) => + p.sortBy(_.name) + } val packList = packMap.toList.sortBy(_._1).map(_._2) law(packList, ser _, deser _)(Eq.fromUniversalEquals) diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala index 2b1882790..eae21c9d7 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/CodeTest.scala @@ -3,14 +3,17 @@ package org.bykn.bosatsu.codegen.python import cats.data.NonEmptyList import java.math.BigInteger import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.python.core.{ParserFacade => JythonParserFacade} import org.scalatest.funsuite.AnyFunSuite class CodeTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) lazy val genPy2Name: Gen[String] = { @@ -38,10 +41,15 @@ class CodeTest extends AnyFunSuite { Gen.oneOf( Gen.identifier.map(Code.PyString), genIdent, - Gen.oneOf(Code.Const.Zero, Code.Const.One, Code.Const.True, Code.Const.False), + Gen.oneOf( + Code.Const.Zero, + Code.Const.One, + Code.Const.True, + Code.Const.False + ), genDotselect, - Gen.choose(-1024, 1024).map(Code.fromInt)) - + Gen.choose(-1024, 1024).map(Code.fromInt) + ) if (depth <= 0) genZero else { @@ -57,9 +65,11 @@ class CodeTest extends AnyFunSuite { Code.Const.Eq, Code.Const.Neq, Code.Const.Gt, - Code.Const.Lt) + Code.Const.Lt + ) - val genOp = Gen.zip(rec, opName, rec).map { case (a, b, c) => Code.Op(a, b, c) } + val genOp = + Gen.zip(rec, opName, rec).map { case (a, b, c) => Code.Op(a, b, c) } val genTup = for { @@ -88,9 +98,19 @@ class CodeTest extends AnyFunSuite { (1, genOp), (2, rec.map(Code.Parens(_))), (2, Gen.zip(rec, Gen.choose(0, 100)).map { case (a, p) => a.get(p) }), - (1, Gen.zip(rec, Gen.option(rec), Gen.option(rec)).map { case (a, s, e) => Code.SelectRange(a, s, e) }), + ( + 1, + Gen.zip(rec, Gen.option(rec), Gen.option(rec)).map { case (a, s, e) => + Code.SelectRange(a, s, e) + } + ), (1, Gen.oneOf(genTup, genList)), // these can really blow things up - (2, Gen.zip(Gen.listOf(genIdent), rec).map { case (args, x) => Code.Lambda(args, x) }), + ( + 2, + Gen.zip(Gen.listOf(genIdent), rec).map { case (args, x) => + Code.Lambda(args, x) + } + ), (1, genApp), (1, genTern) ) @@ -108,7 +128,12 @@ class CodeTest extends AnyFunSuite { Gen.frequency( (10, recX), (1, Gen.zip(recS, rec).map { case (s, r) => s.withValue(r) }), - (1, Gen.zip(genNel(4, cond), rec).map { case (conds, e) => Code.IfElse(conds, e) }) + ( + 1, + Gen.zip(genNel(4, cond), rec).map { case (conds, e) => + Code.IfElse(conds, e) + } + ) ) } @@ -118,11 +143,12 @@ class CodeTest extends AnyFunSuite { lst <- Gen.listOfN(cnt, genA) } yield NonEmptyList.fromListUnsafe(lst) - def genStatement(depth: Int): Gen[Code.Statement] = { val genZero = { val gp = Gen.const(Code.Pass) - val genImp = Gen.zip(genPy2Name, Gen.option(genIdent)).map { case (m, a) => Code.Import(m, a) } + val genImp = Gen.zip(genPy2Name, Gen.option(genIdent)).map { + case (m, a) => Code.Import(m, a) + } Gen.oneOf(gp, genImp) } @@ -151,8 +177,10 @@ class CodeTest extends AnyFunSuite { val genBlock = genNel(5, recStmt).map(Code.Block(_)) val genRet = recVL.map(Code.toReturn(_)) val genAlways = recVL.map(Code.always(_)) - val genAssign = Gen.zip(genIdent, recVL).map { case (v, e) => Code.addAssign(v, e) } - val genWhile = Gen.zip(recExpr, recStmt).map { case (c, b) => Code.While(c, b) } + val genAssign = + Gen.zip(genIdent, recVL).map { case (v, e) => Code.addAssign(v, e) } + val genWhile = + Gen.zip(recExpr, recStmt).map { case (c, b) => Code.While(c, b) } val genIf = for { conds <- genNel(4, Gen.zip(recExpr, recStmt)) @@ -160,11 +188,11 @@ class CodeTest extends AnyFunSuite { } yield Code.ifStatement(conds, elseCond) val genDef = - for { - name <- genIdent - args <- Gen.listOf(genIdent) - body <- recStmt - } yield Code.Def(name, args, body) + for { + name <- genIdent + args <- Gen.listOf(genIdent) + body <- recStmt + } yield Code.Def(name, args, body) Gen.frequency( (20, genZero), @@ -183,10 +211,15 @@ class CodeTest extends AnyFunSuite { def assertParse(str: String) = { try { - val mod = JythonBarrier.run(JythonParserFacade.parseExpressionOrModule(new java.io.StringReader(str), "filename.py", new org.python.core.CompilerFlags())) + val mod = JythonBarrier.run( + JythonParserFacade.parseExpressionOrModule( + new java.io.StringReader(str), + "filename.py", + new org.python.core.CompilerFlags() + ) + ) assert(mod != null) - } - catch { + } catch { case _: Throwable => val msg = "\n\n" + ("=" * 80) + "\n\n" + str + "\n\n" + ("=" * 80) assert(false, msg) @@ -219,19 +252,26 @@ else: test("test some Operator examples") { import Code._ - val apbpc = Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Plus, Ident("c"))) + val apbpc = + Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Plus, Ident("c"))) assert(toDoc(apbpc).renderTrim(80) == """a + b + c""") - val apbmc = Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Minus, Ident("c"))) + val apbmc = + Op(Ident("a"), Const.Plus, Op(Ident("b"), Const.Minus, Ident("c"))) assert(toDoc(apbmc).renderTrim(80) == """a + b - c""") - val ambmc = Op(Ident("a"), Const.Minus, Op(Ident("b"), Const.Minus, Ident("c"))) + val ambmc = + Op(Ident("a"), Const.Minus, Op(Ident("b"), Const.Minus, Ident("c"))) assert(toDoc(ambmc).renderTrim(80) == """a - (b - c)""") - val amzmbmc = Op(Op(Ident("a"), Const.Minus, Ident("z")), Const.Minus, Op(Ident("b"), Const.Minus, Ident("c"))) + val amzmbmc = Op( + Op(Ident("a"), Const.Minus, Ident("z")), + Const.Minus, + Op(Ident("b"), Const.Minus, Ident("c")) + ) assert(toDoc(amzmbmc).renderTrim(80) == """(a - z) - (b - c)""") } @@ -250,11 +290,9 @@ else: if (cmp == 0) { assert(p1.eval(Code.Const.Eq, p2) == Code.Const.True) - } - else if (cmp < 0) { + } else if (cmp < 0) { assert(p1.eval(Code.Const.Lt, p2) == Code.Const.True) - } - else { + } else { assert(p1.eval(Code.Const.Gt, p2) == Code.Const.True) } } @@ -302,17 +340,39 @@ else: val gop = Gen.oneOf(Code.Const.Plus, Code.Const.Minus, Code.Const.Times) forAll(gi, gi, gi, gop, gop) { (a, b, c, op1, op2) => - val left = Code.Op(Code.Op(Code.fromLong(a), op1, Code.fromLong(b)), op2, Code.fromLong(c)) - assert(left.simplify == Code.PyInt(op2(op1(BigInteger.valueOf(a), BigInteger.valueOf(b)), BigInteger.valueOf(c)))) + val left = Code.Op( + Code.Op(Code.fromLong(a), op1, Code.fromLong(b)), + op2, + Code.fromLong(c) + ) + assert( + left.simplify == Code.PyInt( + op2( + op1(BigInteger.valueOf(a), BigInteger.valueOf(b)), + BigInteger.valueOf(c) + ) + ) + ) - val right = Code.Op(Code.fromLong(a), op1, Code.Op(Code.fromLong(b), op2, Code.fromLong(c))) - assert(right.simplify == Code.PyInt(op1(BigInteger.valueOf(a), op2(BigInteger.valueOf(b), BigInteger.valueOf(c))))) + val right = Code.Op( + Code.fromLong(a), + op1, + Code.Op(Code.fromLong(b), op2, Code.fromLong(c)) + ) + assert( + right.simplify == Code.PyInt( + op1( + BigInteger.valueOf(a), + op2(BigInteger.valueOf(b), BigInteger.valueOf(c)) + ) + ) + ) } } def runAll(op: Code.Expression): Option[Code.PyInt] = op match { - case pi@Code.PyInt(_) => Some(pi) + case pi @ Code.PyInt(_) => Some(pi) case Code.Op(left, op: Code.IntOp, right) => for { l <- runAll(left) @@ -321,7 +381,11 @@ else: case _ => None } - def genOp(depth: Int, go: Gen[Code.IntOp], gen0: Gen[Code.Expression]): Gen[Code.Expression] = + def genOp( + depth: Int, + go: Gen[Code.IntOp], + gen0: Gen[Code.Expression] + ): Gen[Code.Expression] = if (depth <= 0) gen0 else { val rec = Gen.lzy(genIntOp(depth - 1, go)) @@ -334,14 +398,22 @@ else: def genIntOp(depth: Int, go: Gen[Code.IntOp]): Gen[Code.Expression] = genOp(depth, go, Gen.choose(-1024, 1024).map(Code.fromInt)) - test("any sequence of IntOps is optimized") { - forAll(genIntOp(5, Gen.oneOf(Code.Const.Plus, Code.Const.Minus, Code.Const.Times))) { op => + forAll( + genIntOp( + 5, + Gen.oneOf(Code.Const.Plus, Code.Const.Minus, Code.Const.Times) + ) + ) { op => // adding zero collapses to an Int assert(Some(op.evalPlus(Code.fromInt(0))) == runAll(op)) assert(Some(Code.fromInt(0).evalPlus(op)) == runAll(op)) assert(Some(op.evalMinus(Code.fromInt(0))) == runAll(op)) - assert(Some(Code.fromInt(0).evalMinus(op)) == runAll(op.evalTimes(Code.fromInt(-1)))) + assert( + Some(Code.fromInt(0).evalMinus(op)) == runAll( + op.evalTimes(Code.fromInt(-1)) + ) + ) assert(Some(Code.fromInt(1).evalTimes(op)) == runAll(op)) } } @@ -350,13 +422,21 @@ else: val gen = genOp( 5, Gen.oneOf(Code.Const.Plus, Code.Const.Minus), - Gen.oneOf(Gen.choose(-1024, 1024).map(Code.fromInt), Gen.identifier.map(Code.Ident(_)))) + Gen.oneOf( + Gen.choose(-1024, 1024).map(Code.fromInt), + Gen.identifier.map(Code.Ident(_)) + ) + ) forAll(gen) { op => val simpOp = op.simplify - def assertGood(x: Code.Expression, isRight: Boolean): org.scalatest.Assertion = + def assertGood( + x: Code.Expression, + isRight: Boolean + ): org.scalatest.Assertion = x match { - case Code.PyInt(_) => assert(isRight, s"found: $x on the left inside of $simpOp") + case Code.PyInt(_) => + assert(isRight, s"found: $x on the left inside of $simpOp") case Code.Op(left, _, right) => assertGood(left, false) assertGood(right, isRight) @@ -376,20 +456,20 @@ else: assert(block(Pass, Pass) == Pass) forAll(genNel(4, genStatement(3))) { case NonEmptyList(h, t) => - val stmt = block(h, t :_*) + val stmt = block(h, t: _*) def passCount(s: Statement): Int = s match { - case Pass => 1 + case Pass => 1 case Block(s) => s.toList.map(passCount).sum - case _ => 0 + case _ => 0 } def notPassCount(s: Statement): Int = s match { - case Pass => 0 + case Pass => 0 case Block(s) => s.toList.map(notPassCount).sum - case _ => 1 + case _ => 1 } val pc = passCount(stmt) @@ -410,7 +490,14 @@ else: val regressions: List[Code.Expression] = List( - Code.SelectItem(Code.Ternary(Code.fromInt(0), Code.fromInt(0), Code.MakeTuple(List(Code.fromInt(42)))), 0) + Code.SelectItem( + Code.Ternary( + Code.fromInt(0), + Code.fromInt(0), + Code.MakeTuple(List(Code.fromInt(42))) + ), + 0 + ) ) regressions.foreach { expr => @@ -425,15 +512,13 @@ else: case Code.PyBool(b) => if (b) { assert(tern == t.simplify) - } - else { + } else { assert(tern == f.simplify) } case Code.PyInt(i) => if (i != BigInteger.ZERO) { assert(tern == t.simplify) - } - else { + } else { assert(tern == f.simplify) } case whoKnows => @@ -446,7 +531,7 @@ else: forAll(genExpr(4)) { expr => expr.identOrParens match { case Code.Ident(_) | Code.Parens(_) => assert(true) - case other => assert(false, other.toString) + case other => assert(false, other.toString) } } } @@ -457,6 +542,10 @@ else: val and = left.evalAnd(right) assert(Code.toDoc(and).renderTrim(80) == "(a == b) and (b == c)") - assert(Code.toDoc(Code.Ident("z").evalAnd(and)).renderTrim(80) == "z and (a == b) and (b == c)") + assert( + Code + .toDoc(Code.Ident("z").evalAnd(and)) + .renderTrim(80) == "z and (a == b) and (b == c)" + ) } } diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala index 4a75dedde..f62c52317 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala @@ -5,9 +5,19 @@ import cats.data.NonEmptyList import java.io.{ByteArrayInputStream, InputStream} import java.nio.file.{Paths, Files} import java.util.concurrent.Semaphore -import org.bykn.bosatsu.{PackageMap, MatchlessFromTypedExpr, Parser, Package, LocationMap, PackageName} +import org.bykn.bosatsu.{ + PackageMap, + MatchlessFromTypedExpr, + Parser, + Package, + LocationMap, + PackageName +} import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.python.util.PythonInterpreter import org.python.core.{PyInteger, PyFunction, PyObject, PyTuple} @@ -58,13 +68,21 @@ class PythonGenTest extends AnyFunSuite { tup.getArray()(0) match { case x if x == zero => // True == one in our encoding - assert(tup.getArray()(1) == one, prefix + "/" + tup.getArray()(2).toString) + assert( + tup.getArray()(1) == one, + prefix + "/" + tup.getArray()(2).toString + ) () case x if x == one => val suite = tup.getArray()(1).toString - foreachList(tup.getArray()(2)) { t => checkTest(t, prefix + "/" + suite); () } + foreachList(tup.getArray()(2)) { t => + checkTest(t, prefix + "/" + suite); () + } case other => - assert(false, s"expected a Test to have 0 or 1 in first tuple entry: $tup, $other") + assert( + false, + s"expected a Test to have 0 or 1 in first tuple entry: $tup, $other" + ) () } } @@ -73,7 +91,6 @@ class PythonGenTest extends AnyFunSuite { def toS(s: String): String = new String(Files.readAllBytes(Paths.get(s)), "UTF-8") - val packNEL = NonEmptyList(path, rest.toList) .map { s => @@ -85,7 +102,7 @@ class PythonGenTest extends AnyFunSuite { val res = PackageMap.typeCheckParsed(packNEL, Nil, "") res.left match { case Some(err) => sys.error(err.toString) - case None => () + case None => () } res.right.get @@ -102,7 +119,8 @@ class PythonGenTest extends AnyFunSuite { val bosatsuPM = compileFile(natPathBosatu) val matchless = MatchlessFromTypedExpr.compile(bosatsuPM) - val packMap = PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) + val packMap = + PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) val natDoc = packMap(PackageName.parts("Bosatsu", "Nat"))._2 JythonBarrier.run { @@ -121,8 +139,7 @@ class PythonGenTest extends AnyFunSuite { val res = fn.__call__(arg) if (i <= 0) { assert(res == new PyInteger(0)) - } - else { + } else { assert(fn.__call__(arg) == arg) } } @@ -143,48 +160,52 @@ class PythonGenTest extends AnyFunSuite { JythonBarrier.run(intr.close()) - def runBoTests(path: String, pn: PackageName, testName: String) = JythonBarrier.run { - val intr = new PythonInterpreter() - - val bosatsuPM = compileFile(path) - val matchless = MatchlessFromTypedExpr.compile(bosatsuPM) + def runBoTests(path: String, pn: PackageName, testName: String) = + JythonBarrier.run { + val intr = new PythonInterpreter() - val packMap = PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) - val doc = packMap(pn)._2 + val bosatsuPM = compileFile(path) + val matchless = MatchlessFromTypedExpr.compile(bosatsuPM) - intr.execfile(isfromString(doc.renderTrim(80)), "test.py") - checkTest(intr.get(testName), pn.asString) + val packMap = + PythonGen.renderAll(matchless, Map.empty, Map.empty, Map.empty) + val doc = packMap(pn)._2 - intr.close() - } + intr.execfile(isfromString(doc.renderTrim(80)), "test.py") + checkTest(intr.get(testName), pn.asString) + intr.close() + } test("we can compile StrConcatExample") { runBoTests( "test_workspace/StrConcatExample.bosatsu", PackageName.parts("StrConcatExample"), - "test") + "test" + ) } - test("test some list pattern matches") { runBoTests( "test_workspace/ListPat.bosatsu", PackageName.parts("ListPat"), - "tests") + "tests" + ) } test("test euler6") { runBoTests( "test_workspace/euler6.bosatsu", PackageName.parts("Euler", "P6"), - "tests") + "tests" + ) } test("test PredefTests") { runBoTests( "test_workspace/PredefTests.bosatsu", PackageName.parts("PredefTests"), - "test") + "test" + ) } } diff --git a/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala b/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala index 5744722d8..052bc4bd2 100644 --- a/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala +++ b/core/.js/src/main/scala/org/bykn/bosatsu/Par.scala @@ -1,11 +1,10 @@ package org.bykn.bosatsu -/** - * This is an abstraction to handle parallel computation, not effectful - * computation. It is used in places where we have parallelism in expensive - * operations. Since scalajs cannot handle this, we use conditional build - * to replace the scalajs with just running directly - */ +/** This is an abstraction to handle parallel computation, not effectful + * computation. It is used in places where we have parallelism in expensive + * operations. Since scalajs cannot handle this, we use conditional build to + * replace the scalajs with just running directly + */ object Par { class Box[A] { private[this] var value: A = _ @@ -34,4 +33,3 @@ object Par { @inline def toF[A](pa: P[A]): F[A] = pa.get } - diff --git a/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala b/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala index 0c4fe41aa..5da6397ce 100644 --- a/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala +++ b/core/.jvm/src/main/scala/org/bykn/bosatsu/Par.scala @@ -4,12 +4,11 @@ import java.util.concurrent.Executors import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.concurrent.duration.Duration -/** - * This is an abstraction to handle parallel computation, not effectful - * computation. It is used in places where we have parallelism in expensive - * operations. Since scalajs cannot handle this, we use conditional build - * to replace the scalajs with just running directly - */ +/** This is an abstraction to handle parallel computation, not effectful + * computation. It is used in places where we have parallelism in expensive + * operations. Since scalajs cannot handle this, we use conditional build to + * replace the scalajs with just running directly + */ object Par { type F[A] = Future[A] type P[A] = Promise[A] @@ -22,7 +21,8 @@ object Par { def shutdownService(es: ExecutionService): Unit = es.shutdown() - def ecFromService(es: ExecutionService): EC = ExecutionContext.fromExecutor(es) + def ecFromService(es: ExecutionService): EC = + ExecutionContext.fromExecutor(es) @inline def start[A](a: => A)(implicit ec: EC): F[A] = Future(a) @@ -39,4 +39,3 @@ object Par { @inline def toF[A](pa: P[A]): F[A] = pa.future } - diff --git a/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala b/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala index ab37868f0..9078d3539 100644 --- a/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/BindingStatement.scala @@ -1,16 +1,18 @@ package org.bykn.bosatsu -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} case class BindingStatement[B, V, T](name: B, value: V, in: T) object BindingStatement { private[this] val eqDoc = Doc.text(" = ") - implicit def document[A: Document, V: Document, T: Document]: Document[BindingStatement[A, V, T]] = + implicit def document[A: Document, V: Document, T: Document] + : Document[BindingStatement[A, V, T]] = Document.instance[BindingStatement[A, V, T]] { let => import let._ - Document[A].document(name) + eqDoc + Document[V].document(value) + Document[T].document(in) + Document[A].document(name) + eqDoc + Document[V].document( + value + ) + Document[T].document(in) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala b/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala index 0a3ce7b75..5f87819a5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala +++ b/core/src/main/scala/org/bykn/bosatsu/CollectionUtils.scala @@ -8,10 +8,13 @@ import scala.util.{Success, Failure, Try} import cats.implicits._ object CollectionUtils { - /** - * Return the unique keys on the Right, and the duplicate keys on the Left (and possibly Both) - */ - def uniqueByKey[A, B: Order](as: NonEmptyList[A])(fn: A => B): Ior[NonEmptyMap[B, (A, NonEmptyList[A])], NonEmptyMap[B, A]] = { + + /** Return the unique keys on the Right, and the duplicate keys on the Left + * (and possibly Both) + */ + def uniqueByKey[A, B: Order](as: NonEmptyList[A])( + fn: A => B + ): Ior[NonEmptyMap[B, (A, NonEmptyList[A])], NonEmptyMap[B, A]] = { def check(as: NonEmptyList[A]): Either[(A, NonEmptyList[A]), A] = as match { case NonEmptyList(a, Nil) => @@ -21,13 +24,16 @@ object CollectionUtils { } // We know this is nonEmpty, so good and bad can't both be empty - val checked: SortedMap[B, Either[(A, NonEmptyList[A]), A]] = as.groupBy(fn).map { case (b, as) => (b, check(as)) } + val checked: SortedMap[B, Either[(A, NonEmptyList[A]), A]] = + as.groupBy(fn).map { case (b, as) => (b, check(as)) } val good: SortedMap[B, A] = checked.collect { case (b, Right(a)) => (b, a) } - val bad: SortedMap[B, (A, NonEmptyList[A])] = checked.collect { case (b, Left(a)) => (b, a) } + val bad: SortedMap[B, (A, NonEmptyList[A])] = checked.collect { + case (b, Left(a)) => (b, a) + } (NonEmptyMap.fromMap(bad), NonEmptyMap.fromMap(good)) match { - case (None, Some(goodNE)) => Ior.right(goodNE) - case (Some(badNE), None) => Ior.left(badNE) + case (None, Some(goodNE)) => Ior.right(goodNE) + case (Some(badNE), None) => Ior.left(badNE) case (Some(badNE), Some(goodNE)) => Ior.both(badNE, goodNE) // $COVERAGE-OFF$ case _ => @@ -36,17 +42,20 @@ object CollectionUtils { } } - def listToUnique[A, K: Order, V](l: List[A])(key: A => K, value: A => V, msg: => String): Try[SortedMap[K, V]] = + def listToUnique[A, K: Order, V]( + l: List[A] + )(key: A => K, value: A => V, msg: => String): Try[SortedMap[K, V]] = NonEmptyList.fromList(l) match { case None => Success(SortedMap.empty[K, V]) case Some(nel) => uniqueByKey(nel)(key) match { - case Ior.Right(b) => Success(b.toSortedMap.map { case (k, a) => (k, value(a)) }) + case Ior.Right(b) => + Success(b.toSortedMap.map { case (k, a) => (k, value(a)) }) case Ior.Left(errMap) => Failure(new IllegalArgumentException(s"$msg, $errMap")) case Ior.Both(errMap, _) => Failure(new IllegalArgumentException(s"$msg, $errMap")) - } + } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala b/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala index e9b3b34e1..102f00326 100644 --- a/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/CommentStatement.scala @@ -2,12 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} -/** - * Represents a commented thing. Commented[A] would probably - * be a better name - */ +/** Represents a commented thing. Commented[A] would probably be a better name + */ final case class CommentStatement[T](message: NonEmptyList[String], on: T) object CommentStatement { @@ -16,20 +14,24 @@ object CommentStatement { implicit def document[T: Document]: Document[CommentStatement[T]] = Document.instance[CommentStatement[T]] { comment => import comment._ - val block = Doc.intercalate(Doc.line, message.toList.map { mes => Doc.char('#') + Doc.text(mes) }) + val block = Doc.intercalate( + Doc.line, + message.toList.map { mes => Doc.char('#') + Doc.text(mes) } + ) block + Doc.line + Document[T].document(on) } - /** on should make sure indent is matching - * this is to allow a P[Unit] that does nothing for testing or other applications - */ + /** on should make sure indent is matching this is to allow a P[Unit] that + * does nothing for testing or other applications + */ def parser[T](onP: String => P0[T]): Parser.Indy[CommentStatement[T]] = Parser.Indy { indent => val sep = Parser.newline val commentBlock: P[NonEmptyList[String]] = // if the next line is part of the comment until we see the # or not - (Parser.maybeSpace.with1.soft *> commentPart).repSep(sep = sep) <* Parser.newline.orElse(P.end) + (Parser.maybeSpace.with1.soft *> commentPart) + .repSep(sep = sep) <* Parser.newline.orElse(P.end) (commentBlock ~ onP(indent)) .map { case (m, on) => CommentStatement(m, on) } @@ -38,5 +40,3 @@ object CommentStatement { val commentPart: P[String] = P.char('#') *> P.until0(P.char('\n')) } - - diff --git a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala index 8c81bf11c..e566b96f6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Declaration.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Declaration.scala @@ -1,10 +1,19 @@ package org.bykn.bosatsu -import Parser.{ Combinators, Indy, maybeSpace, maybeSpacesAndLines, spaces, toEOL1, keySpace, MaybeTupleOrParens } +import Parser.{ + Combinators, + Indy, + maybeSpace, + maybeSpacesAndLines, + spaces, + toEOL1, + keySpace, + MaybeTupleOrParens +} import cats.data.NonEmptyList import org.bykn.bosatsu.graph.Memoize import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import scala.collection.immutable.SortedSet import Indy.IndyMethods @@ -14,9 +23,9 @@ import ListLang.{KVPair, SpliceOrItem} import Identifier.{Bindable, Constructor} import cats.implicits._ -/** - * Represents the syntactic version of Expr - */ + +/** Represents the syntactic version of Expr + */ sealed abstract class Declaration { import Declaration._ @@ -44,32 +53,48 @@ sealed abstract class Declaration { (args.head.toDoc + Doc.char('.') + fnDoc, args.tail) } - prefix + Doc.char('(') + Doc.intercalate(Doc.text(", "), body.map(_.toDoc)) + Doc.char(')') + prefix + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + body.map(_.toDoc) + ) + Doc.char(')') case ApplyOp(left, Identifier.Operator(opStr), right) => left.toDoc space Doc.text(opStr) space right.toDoc case Binding(b) => val d0 = Document[Padding[Declaration]] val withNewLine = Document.instance[Padding[Declaration]] { pd => - Doc.line + d0.document(pd) + Doc.line + d0.document(pd) } - BindingStatement.document(Document[Pattern.Parsed], Document.instance[NonBinding](_.toDoc), withNewLine).document(b) + BindingStatement + .document( + Document[Pattern.Parsed], + Document.instance[NonBinding](_.toDoc), + withNewLine + ) + .document(b) case LeftApply(pat, _, arg, body) => - Document[Pattern.Parsed].document(pat) + Doc.text(" <- ") + arg.toDoc + Doc.line + - Document[Padding[Declaration]].document(body) + Document[Pattern.Parsed].document(pat) + Doc.text( + " <- " + ) + arg.toDoc + Doc.line + + Document[Padding[Declaration]].document(body) case Comment(c) => CommentStatement.document[Padding[Declaration]].document(c) case CommentNB(c) => CommentStatement.document[Padding[NonBinding]].document(c) case DefFn(d) => - implicit val pairDoc: Document[(OptIndent[Declaration], Padding[Declaration])] = - Document.instance { - case (fnBody, letBody) => - fnBody.sepDoc + - Document[OptIndent[Declaration]].document(fnBody) + - Doc.line + - Document[Padding[Declaration]].document(letBody) + implicit val pairDoc + : Document[(OptIndent[Declaration], Padding[Declaration])] = + Document.instance { case (fnBody, letBody) => + fnBody.sepDoc + + Document[OptIndent[Declaration]].document(fnBody) + + Doc.line + + Document[Padding[Declaration]].document(letBody) } - DefStatement.document[Pattern.Parsed, (OptIndent[Declaration], Padding[Declaration])].document(d) + DefStatement + .document[ + Pattern.Parsed, + (OptIndent[Declaration], Padding[Declaration]) + ] + .document(d) case IfElse(ifCases, elseCase) => def checkBody(cb: (Declaration, OptIndent[Declaration])): Doc = { val (check, optbody) = cb @@ -78,16 +103,22 @@ sealed abstract class Declaration { check.toDoc + Doc.char(':') + rest } - val elseDoc = elseCase.sepDoc + Document[OptIndent[Declaration]].document(elseCase) + val elseDoc = + elseCase.sepDoc + Document[OptIndent[Declaration]].document(elseCase) val tail = Doc.text("else:") + elseDoc :: Nil - val parts = (Doc.text("if ") + checkBody(ifCases.head)) :: (ifCases.tail.map(Doc.text("elif ") + checkBody(_))) ::: tail + val parts = (Doc.text("if ") + checkBody(ifCases.head)) :: (ifCases.tail + .map(Doc.text("elif ") + checkBody(_))) ::: tail Doc.intercalate(Doc.line, parts) case Ternary(trueCase, cond, falseCase) => - Doc.intercalate(Doc.space, - trueCase.toDoc :: Doc.text("if") :: cond.toDoc :: Doc.text("else") :: falseCase.toDoc :: Nil) + Doc.intercalate( + Doc.space, + trueCase.toDoc :: Doc.text("if") :: cond.toDoc :: Doc.text( + "else" + ) :: falseCase.toDoc :: Nil + ) case Lambda(args, body) => // slash style: - //val argDoc = Doc.char('\\') + Doc.intercalate(Doc.text(", "), args.toList.map(Document[Pattern.Parsed].document(_))) + // val argDoc = Doc.char('\\') + Doc.intercalate(Doc.text(", "), args.toList.map(Document[Pattern.Parsed].document(_))) // bare style: val argDoc = args match { case NonEmptyList(one, Nil) => @@ -95,11 +126,13 @@ sealed abstract class Declaration { if (Pattern.isNonUnitTuple(one)) { // wrap with parens Doc.char('(') + od + Doc.char(')') - } - else od + } else od case args => // more than one must wrap in () - Doc.char('(') + Doc.intercalate(Doc.text(", "), args.toList.map(Document[Pattern.Parsed].document(_))) + Doc.char(')') + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + args.toList.map(Document[Pattern.Parsed].document(_)) + ) + Doc.char(')') } argDoc + Doc.text(" -> ") + body.toDoc case Literal(lit) => Document[Lit].document(lit) @@ -108,26 +141,34 @@ sealed abstract class Declaration { val caseDoc = Doc.text("case ") - implicit val patDoc: Document[(Pattern.Parsed, OptIndent[Declaration])] = + implicit val patDoc + : Document[(Pattern.Parsed, OptIndent[Declaration])] = Document.instance[(Pattern.Parsed, OptIndent[Declaration])] { case (pat, decl) => - caseDoc + Document[Pattern.Parsed].document(pat) + Doc.text(":") + decl.sepDoc + pid.document(decl) + caseDoc + Document[Pattern.Parsed].document(pat) + Doc.text( + ":" + ) + decl.sepDoc + pid.document(decl) } implicit def linesDoc[T: Document]: Document[NonEmptyList[T]] = - Document.instance { ts => Doc.intercalate(Doc.line, ts.toList.map(Document[T].document _)) } + Document.instance { ts => + Doc.intercalate(Doc.line, ts.toList.map(Document[T].document _)) + } - val piPat = Document[OptIndent[NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])]]] + val piPat = Document[OptIndent[ + NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])] + ]] val kindDoc = kind match { case RecursionKind.NonRecursive => Doc.text("match ") - case RecursionKind.Recursive => Doc.text("recur ") + case RecursionKind.Recursive => Doc.text("recur ") } // TODO this isn't quite right kindDoc + typeName.toDoc + Doc.char(':') + args.sepDoc + piPat.document(args) - case m@Matches(arg, p) => + case m @ Matches(arg, p) => val da = arg match { // matches binds tighter than all these - case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | Match(_, _, _) => + case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | + Match(_, _, _) => Parens(arg)(m.region).toDoc case _ => arg.toDoc @@ -139,22 +180,26 @@ sealed abstract class Declaration { // we need a trailing comma here: Doc.char('(') + h.toDoc + Doc.char(',') + Doc.char(')') case TupleCons(items) => - Doc.char('(') + Doc.intercalate(Doc.text(", "), - items.map(_.toDoc)) + Doc.char(')') + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + items.map(_.toDoc) + ) + Doc.char(')') case Var(name) => Document[Identifier].document(name) case StringDecl(parts) => val useDouble = parts.exists { case Right((_, str)) => str.contains('\'') && !str.contains('"') - case Left(_) => false + case Left(_) => false } val q = if (useDouble) '"' else '\'' - val inner = Doc.intercalate(Doc.empty, + val inner = Doc.intercalate( + Doc.empty, parts.toList.map { case Right((_, str)) => Doc.text(StringUtil.escape(q, str)) - case Left(decl) => Doc.text("${") + decl.toDoc + Doc.char('}') - }) + case Left(decl) => Doc.text("${") + decl.toDoc + Doc.char('}') + } + ) Doc.char(q) + inner + Doc.char(q) case ListDecl(list) => @@ -165,24 +210,28 @@ sealed abstract class Declaration { case RecordConstructor(name, args) => val argDoc = Doc.char('{') + - Doc.intercalate(Doc.char(',') + Doc.space, - args.toList.map(_.toDoc)) + Doc.char('}') + Doc.intercalate( + Doc.char(',') + Doc.space, + args.toList.map(_.toDoc) + ) + Doc.char('}') Declaration.identDoc.document(name) + Doc.space + argDoc } - /** - * Get the set of free variables in this declaration. - * These are variables that must be defined at an outer - * lexical scope in order to typecheck - */ + /** Get the set of free variables in this declaration. These are variables + * that must be defined at an outer lexical scope in order to typecheck + */ def freeVars: SortedSet[Bindable] = { - def loop(decl: Declaration, bound: Set[Bindable], acc: SortedSet[Bindable]): SortedSet[Bindable] = + def loop( + decl: Declaration, + bound: Set[Bindable], + acc: SortedSet[Bindable] + ): SortedSet[Bindable] = decl match { case Annotation(term, _) => loop(term, bound, acc) case Apply(fn, args, _) => (fn :: args).foldLeft(acc) { (acc0, d) => loop(d, bound, acc0) } - case ao@ApplyOp(left, _, right) => + case ao @ ApplyOp(left, _, right) => val acc0 = loop(left, bound, acc) val acc1 = loop(ao.opVar, bound, acc0) loop(right, bound, acc1) @@ -190,7 +239,7 @@ sealed abstract class Declaration { val acc0 = loop(v, bound, acc) val bound1 = bound ++ n.names loop(in.padded, bound1, acc0) - case Comment(c) => loop(c.on.padded, bound, acc) + case Comment(c) => loop(c.on.padded, bound, acc) case CommentNB(c) => loop(c.on.padded, bound, acc) case DefFn(d) => val (body, rest) = d.result @@ -215,7 +264,7 @@ sealed abstract class Declaration { case Lambda(args, body) => val bound1 = bound ++ args.patternNames loop(body, bound1, acc) - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => loop(la.rewrite, bound, acc) case Literal(_) => acc case Match(_, typeName, args) => @@ -230,11 +279,11 @@ sealed abstract class Declaration { case TupleCons(items) => items.foldLeft(acc) { (acc0, d) => loop(d, bound, acc0) } case Var(name: Bindable) if !bound(name) => acc + name - case Var(_) => acc + case Var(_) => acc case StringDecl(items) => items.foldLeft(acc) { case (acc, Left(nb)) => loop(nb, bound, acc) - case (acc, _) => acc + case (acc, _) => acc } case ListDecl(ListLang.Cons(items)) => items.foldLeft(acc) { (acc0, sori) => @@ -275,25 +324,23 @@ sealed abstract class Declaration { decl match { case Var(_) | Literal(_) => true case Annotation(term, _) => loop(term) - case Parens(p) => loop(p) - case _ => false + case Parens(p) => loop(p) + case _ => false } loop(this) } - /** - * Wrap in Parens is needed - */ + /** Wrap in Parens is needed + */ def toNonBinding: NonBinding = this match { case nb: NonBinding => nb - case decl => Parens(decl)(decl.region) + case decl => Parens(decl)(decl.region) } - /** - * This returns *all* names in the declaration, bound or not - */ + /** This returns *all* names in the declaration, bound or not + */ def allNames: SortedSet[Bindable] = { def loop(decl: Declaration, acc: SortedSet[Bindable]): SortedSet[Bindable] = decl match { @@ -306,9 +353,9 @@ sealed abstract class Declaration { case Binding(BindingStatement(n, v, in)) => val acc0 = loop(v, acc ++ n.names) loop(in.padded, acc0) - case Comment(c) => loop(c.on.padded, acc) + case Comment(c) => loop(c.on.padded, acc) case CommentNB(c) => loop(c.on.padded, acc) - case DefFn(d) => + case DefFn(d) => // def sets up a binding to itself, which // may or may not be recursive val acc1 = (acc + d.name) ++ d.args.toList.flatMap(_.patternNames) @@ -321,7 +368,7 @@ sealed abstract class Declaration { loop(v.get, acc1) } loop(elseCase.get, acc2) - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => loop(la.rewrite, acc) case Ternary(t, c, f) => val acc1 = loop(t, acc) @@ -342,11 +389,11 @@ sealed abstract class Declaration { case TupleCons(items) => items.foldLeft(acc) { (acc0, d) => loop(d, acc0) } case Var(name: Bindable) => acc + name - case Var(_) => acc + case Var(_) => acc case StringDecl(nel) => nel.foldLeft(acc) { case (acc0, Left(decl)) => loop(decl, acc0) - case (acc0, Right(_)) => acc0 + case (acc0, Right(_)) => acc0 } case ListDecl(ListLang.Cons(items)) => items.foldLeft(acc) { (acc0, sori) => @@ -369,7 +416,7 @@ sealed abstract class Declaration { case RecordConstructor(_, args) => args.foldLeft(acc) { case (acc, RecordArg.Pair(_, v)) => loop(v, acc) - case (acc, RecordArg.Simple(n)) => acc + n + case (acc, RecordArg.Simple(n)) => acc + n } } loop(this, SortedSet.empty) @@ -378,11 +425,24 @@ sealed abstract class Declaration { def replaceRegions(r: Region): Declaration = this match { case Binding(BindingStatement(n, v, in)) => - Binding(BindingStatement(n, v.replaceRegionsNB(r), in.map(_.replaceRegions(r))))(r) + Binding( + BindingStatement( + n, + v.replaceRegionsNB(r), + in.map(_.replaceRegions(r)) + ) + )(r) case Comment(CommentStatement(lines, c)) => Comment(CommentStatement(lines, c.map(_.replaceRegions(r))))(r) case DefFn(d) => - DefFn(d.copy(result = (d.result._1.map(_.replaceRegions(r)), d.result._2.map(_.replaceRegions(r)))))(r) + DefFn( + d.copy(result = + ( + d.result._1.map(_.replaceRegions(r)), + d.result._2.map(_.replaceRegions(r)) + ) + ) + )(r) case LeftApply(p, _, right, b) => LeftApply(p, r, right.replaceRegionsNB(r), b.map(_.replaceRegions(r))) case nb: NonBinding => nb.replaceRegionsNB(r) @@ -390,7 +450,8 @@ sealed abstract class Declaration { } object Declaration { - implicit val document: Document[Declaration] = Document.instance[Declaration](_.toDoc) + implicit val document: Document[Declaration] = + Document.instance[Declaration](_.toDoc) implicit val hasRegion: HasRegion[Declaration] = HasRegion.instance[Declaration](_.region) @@ -400,16 +461,18 @@ object Declaration { case object Parens extends ApplyKind } - /** - * Try to substitute ex for ident in the expression: in - * - * This can fail if the free variables in ex are shadowed - * above ident in in. - * - * this code is very similar to TypedExpr.substitute - * if bugs are found in one, consult the other - */ - def substitute[A](ident: Bindable, ex: NonBinding, in: Declaration): Option[Declaration] = { + /** Try to substitute ex for ident in the expression: in + * + * This can fail if the free variables in ex are shadowed above ident in in. + * + * this code is very similar to TypedExpr.substitute if bugs are found in + * one, consult the other + */ + def substitute[A]( + ident: Bindable, + ex: NonBinding, + in: Declaration + ): Option[Declaration] = { // if we hit a shadow, we don't need to substitute down // that branch @inline def shadows(i: Bindable): Boolean = i === ident @@ -418,10 +481,13 @@ object Declaration { // this causes us to return None lazy val masks: Bindable => Boolean = ex.freeVars - def loopLL[F[_]](ll: ListLang[F, NonBinding, Pattern.Parsed])(fn: F[NonBinding] => Option[F[NonBinding]]): Option[ListLang[F, NonBinding, Pattern.Parsed]] = + def loopLL[F[_]](ll: ListLang[F, NonBinding, Pattern.Parsed])( + fn: F[NonBinding] => Option[F[NonBinding]] + ): Option[ListLang[F, NonBinding, Pattern.Parsed]] = ll match { case ListLang.Cons(items) => - items.traverse(fn) + items + .traverse(fn) .map(ListLang.Cons(_)) case ListLang.Comprehension(ex, b, in, filt) => // b sets up bindings for filt and ex @@ -429,7 +495,8 @@ object Declaration { .flatMap { in1 => val pnames = b.names if (pnames.exists(masks)) None - else if (pnames.exists(shadows)) Some(ListLang.Comprehension(ex, b, in1, filt)) + else if (pnames.exists(shadows)) + Some(ListLang.Comprehension(ex, b, in1, filt)) else { // no shadowing or masking (fn(ex), filt.traverse(loop)) @@ -452,8 +519,7 @@ object Declaration { // This is no longer a simple RecordArg Some(RecordArg.Pair(fn, ex)) } - } - else Some(ra) + } else Some(ra) } def loop(decl: NonBinding): Option[NonBinding] = @@ -463,7 +529,7 @@ object Declaration { case Apply(fn, args, kind) => (loop(fn), args.traverse(loop)) .mapN(Apply(_, _, kind)(decl.region)) - case aop@ApplyOp(left, op, right) if (op: Bindable) === ident => + case aop @ ApplyOp(left, op, right) if (op: Bindable) === ident => // we cannot make a general substition on ApplyOp ex match { case Var(op1: Identifier.Operator) => @@ -494,7 +560,7 @@ object Declaration { if (pnames.exists(masks)) None else if (pnames.exists(shadows)) Some(decl) else loopDec(body).map(Lambda(args, _)(decl.region)) - case l@Literal(_) => Some(l) + case l @ Literal(_) => Some(l) case Match(k, arg, cases) => val caseRes = cases @@ -524,7 +590,7 @@ object Declaration { nel .traverse { case Left(nb) => loop(nb).map(Left(_)) - case right => Some(right) + case right => Some(right) } .map(StringDecl(_)(decl.region)) case ListDecl(ll) => @@ -534,7 +600,8 @@ object Declaration { loopLL(ll)(_.traverse(loop)) .map(DictDecl(_)(decl.region)) case RecordConstructor(c, args) => - args.traverse(loopRA) + args + .traverse(loopRA) .map(RecordConstructor(c, _)(decl.region)) } @@ -549,8 +616,7 @@ object Declaration { .map { v1 => Binding(BindingStatement(n, v1, in))(decl.region) } - } - else { + } else { // we substitute on both (loop(v), in.traverse(loopDec)) .mapN { (v1, in1) => @@ -585,8 +651,7 @@ object Declaration { .map { v1 => LeftApply(n, r, v1, in) } - } - else { + } else { // we substitute on both (loop(v), in.traverse(loopDec)) .mapN { (v1, in1) => @@ -629,16 +694,15 @@ object Declaration { (Identifier.bindableParser ~ (pairFn.?)) .map { - case (b, None) => Simple(b) + case (b, None) => Simple(b) case (b, Some(fn)) => fn(b) } } } - /** - * These are all Declarations other than Binding, DefFn and Comment, - * in other words, things that don't need to start with indentation - */ + /** These are all Declarations other than Binding, DefFn and Comment, in other + * words, things that don't need to start with indentation + */ sealed abstract class NonBinding extends Declaration { def replaceRegionsNB(r: Region): NonBinding = this match { @@ -650,17 +714,29 @@ object Declaration { case CommentNB(CommentStatement(msg, p)) => CommentNB(CommentStatement(msg, p.map(_.replaceRegionsNB(r))))(r) case IfElse(ifCases, elseCase) => - IfElse(ifCases.map { case (bool, res) => (bool.replaceRegionsNB(r), res.map(_.replaceRegions(r))) }, - elseCase.map(_.replaceRegions(r)))(r) + IfElse( + ifCases.map { case (bool, res) => + (bool.replaceRegionsNB(r), res.map(_.replaceRegions(r))) + }, + elseCase.map(_.replaceRegions(r)) + )(r) case Ternary(t, c, f) => - Ternary(t.replaceRegionsNB(r), c.replaceRegionsNB(r), f.replaceRegionsNB(r)) + Ternary( + t.replaceRegionsNB(r), + c.replaceRegionsNB(r), + f.replaceRegionsNB(r) + ) case Lambda(args, body) => Lambda(args, body.replaceRegions(r))(r) case Literal(lit) => Literal(lit)(r) case Match(rec, arg, branches) => - Match(rec, + Match( + rec, arg.replaceRegionsNB(r), - branches.map(_.map { case (p, x) => (p, x.map(_.replaceRegions(r))) }))(r) + branches.map(_.map { case (p, x) => + (p, x.map(_.replaceRegions(r))) + }) + )(r) case Matches(a, p) => Matches(a.replaceRegionsNB(r), p)(r) case Parens(p) => Parens(p.replaceRegions(r))(r) @@ -670,24 +746,38 @@ object Declaration { case StringDecl(nel) => val ne1 = nel.map { case Right((_, s)) => Right((r, s)) - case Left(e) => Left(e.replaceRegionsNB(r)) + case Left(e) => Left(e.replaceRegionsNB(r)) } StringDecl(ne1)(r) case ListDecl(ListLang.Cons(items)) => ListDecl(ListLang.Cons(items.map(_.map(_.replaceRegionsNB(r)))))(r) case ListDecl(ListLang.Comprehension(ex, b, in, filter)) => - ListDecl(ListLang.Comprehension(ex.map(_.replaceRegionsNB(r)), b, in.replaceRegionsNB(r), filter.map(_.replaceRegionsNB(r))))(r) + ListDecl( + ListLang.Comprehension( + ex.map(_.replaceRegionsNB(r)), + b, + in.replaceRegionsNB(r), + filter.map(_.replaceRegionsNB(r)) + ) + )(r) case DictDecl(ListLang.Cons(items)) => - DictDecl(ListLang.Cons(items.map { - case ListLang.KVPair(k, v) => - ListLang.KVPair(k.replaceRegionsNB(r), v.replaceRegionsNB(r)) + DictDecl(ListLang.Cons(items.map { case ListLang.KVPair(k, v) => + ListLang.KVPair(k.replaceRegionsNB(r), v.replaceRegionsNB(r)) }))(r) case DictDecl(ListLang.Comprehension(ex, b, in, filter)) => - DictDecl(ListLang.Comprehension(ex.map(_.replaceRegionsNB(r)), b, in.replaceRegionsNB(r), filter.map(_.replaceRegionsNB(r))))(r) + DictDecl( + ListLang.Comprehension( + ex.map(_.replaceRegionsNB(r)), + b, + in.replaceRegionsNB(r), + filter.map(_.replaceRegionsNB(r)) + ) + )(r) case RecordConstructor(c, args) => val args1 = args.map { case RecordArg.Simple(b) => RecordArg.Simple(b) - case RecordArg.Pair(k, v) => RecordArg.Pair(k, v.replaceRegionsNB(r)) + case RecordArg.Pair(k, v) => + RecordArg.Pair(k, v.replaceRegionsNB(r)) } RecordConstructor(c, args1)(r) } @@ -698,16 +788,35 @@ object Declaration { Document.instance(_.toDoc) } - /** - * These are "binding" kinds, (not-NonBinding) + /** These are "binding" kinds, (not-NonBinding) */ - case class Binding(binding: BindingStatement[Pattern.Parsed, NonBinding, Padding[Declaration]])(implicit val region: Region) extends Declaration - case class Comment(comment: CommentStatement[Padding[Declaration]])(implicit val region: Region) extends Declaration - case class DefFn(deffn: DefStatement[Pattern.Parsed, (OptIndent[Declaration], Padding[Declaration])])(implicit val region: Region) extends Declaration - case class LeftApply(arg: Pattern.Parsed, argRegion: Region, fn: NonBinding, result: Padding[Declaration]) extends Declaration { + case class Binding( + binding: BindingStatement[Pattern.Parsed, NonBinding, Padding[ + Declaration + ]] + )(implicit val region: Region) + extends Declaration + case class Comment(comment: CommentStatement[Padding[Declaration]])(implicit + val region: Region + ) extends Declaration + case class DefFn( + deffn: DefStatement[ + Pattern.Parsed, + (OptIndent[Declaration], Padding[Declaration]) + ] + )(implicit val region: Region) + extends Declaration + case class LeftApply( + arg: Pattern.Parsed, + argRegion: Region, + fn: NonBinding, + result: Padding[Declaration] + ) extends Declaration { def region: Region = argRegion + result.padded.region def rewrite: NonBinding = { - val lam = Lambda(NonEmptyList.one(arg), result.padded)(argRegion + result.padded.region) + val lam = Lambda(NonEmptyList.one(arg), result.padded)( + argRegion + result.padded.region + ) Apply(fn, NonEmptyList.one(lam), ApplyKind.Parens)(region) } } @@ -720,82 +829,124 @@ object Declaration { // value in tests and construct them. // These reasons are a bit abusive, and we may revisit this in the future // - case class Annotation(fn: NonBinding, tpe: TypeRef)(implicit val region: Region) extends NonBinding - case class Apply(fn: NonBinding, args: NonEmptyList[NonBinding], kind: ApplyKind)(implicit val region: Region) extends NonBinding - case class ApplyOp(left: NonBinding, op: Identifier.Operator, right: NonBinding) extends NonBinding { + case class Annotation(fn: NonBinding, tpe: TypeRef)(implicit + val region: Region + ) extends NonBinding + case class Apply( + fn: NonBinding, + args: NonEmptyList[NonBinding], + kind: ApplyKind + )(implicit val region: Region) + extends NonBinding + case class ApplyOp( + left: NonBinding, + op: Identifier.Operator, + right: NonBinding + ) extends NonBinding { val region = left.region + right.region def opVar: Var = Var(op)(Region(left.region.end, right.region.start)) def toApply: Apply = Apply(opVar, NonEmptyList(left, right :: Nil), ApplyKind.Parens)(region) } - case class CommentNB(comment: CommentStatement[Padding[NonBinding]])(implicit val region: Region) extends NonBinding - - case class IfElse(ifCases: NonEmptyList[(NonBinding, OptIndent[Declaration])], - elseCase: OptIndent[Declaration])(implicit val region: Region) extends NonBinding - case class Ternary(trueCase: NonBinding, cond: NonBinding, falseCase: NonBinding) extends NonBinding { + case class CommentNB(comment: CommentStatement[Padding[NonBinding]])(implicit + val region: Region + ) extends NonBinding + + case class IfElse( + ifCases: NonEmptyList[(NonBinding, OptIndent[Declaration])], + elseCase: OptIndent[Declaration] + )(implicit val region: Region) + extends NonBinding + case class Ternary( + trueCase: NonBinding, + cond: NonBinding, + falseCase: NonBinding + ) extends NonBinding { val region = trueCase.region + falseCase.region } - case class Lambda(args: NonEmptyList[Pattern.Parsed], body: Declaration)(implicit val region: Region) extends NonBinding + case class Lambda(args: NonEmptyList[Pattern.Parsed], body: Declaration)( + implicit val region: Region + ) extends NonBinding case class Literal(lit: Lit)(implicit val region: Region) extends NonBinding case class Match( - kind: RecursionKind, - arg: NonBinding, - cases: OptIndent[NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])]])( - implicit val region: Region) extends NonBinding - case class Matches(arg: NonBinding, pattern: Pattern.Parsed)(implicit val region: Region) extends NonBinding - case class Parens(of: Declaration)(implicit val region: Region) extends NonBinding - case class TupleCons(items: List[NonBinding])(implicit val region: Region) extends NonBinding - case class Var(name: Identifier)(implicit val region: Region) extends NonBinding - - /** - * This represents code like: - * Foo { bar: 12 } - */ - case class RecordConstructor(cons: Constructor, arg: NonEmptyList[RecordArg])(implicit val region: Region) extends NonBinding - /** - * This represents interpolated strings - */ - case class StringDecl(items: NonEmptyList[Either[NonBinding, (Region, String)]])(implicit val region: Region) extends NonBinding - /** - * This represents the list construction language - */ - case class ListDecl(list: ListLang[SpliceOrItem, NonBinding, Pattern.Parsed])(implicit val region: Region) extends NonBinding - /** - * Here are dict constructors and comprehensions - */ - case class DictDecl(list: ListLang[KVPair, NonBinding, Pattern.Parsed])(implicit val region: Region) extends NonBinding + kind: RecursionKind, + arg: NonBinding, + cases: OptIndent[NonEmptyList[(Pattern.Parsed, OptIndent[Declaration])]] + )(implicit val region: Region) + extends NonBinding + case class Matches(arg: NonBinding, pattern: Pattern.Parsed)(implicit + val region: Region + ) extends NonBinding + case class Parens(of: Declaration)(implicit val region: Region) + extends NonBinding + case class TupleCons(items: List[NonBinding])(implicit val region: Region) + extends NonBinding + case class Var(name: Identifier)(implicit val region: Region) + extends NonBinding + + /** This represents code like: Foo { bar: 12 } + */ + case class RecordConstructor(cons: Constructor, arg: NonEmptyList[RecordArg])( + implicit val region: Region + ) extends NonBinding + + /** This represents interpolated strings + */ + case class StringDecl( + items: NonEmptyList[Either[NonBinding, (Region, String)]] + )(implicit val region: Region) + extends NonBinding + + /** This represents the list construction language + */ + case class ListDecl(list: ListLang[SpliceOrItem, NonBinding, Pattern.Parsed])( + implicit val region: Region + ) extends NonBinding + + /** Here are dict constructors and comprehensions + */ + case class DictDecl(list: ListLang[KVPair, NonBinding, Pattern.Parsed])( + implicit val region: Region + ) extends NonBinding val matchKindParser: P[RecursionKind] = P.string("match") .as(RecursionKind.NonRecursive) .orElse( P.string("recur") - .as(RecursionKind.Recursive)).soft <* Parser.spaces.peek + .as(RecursionKind.Recursive) + ) + .soft <* Parser.spaces.peek - /** - * A pattern can also be a declaration in some cases - * - * TODO, patterns don't parse with regions, so we lose track of precise position information - * if we want to point to an inner portion of it - */ + /** A pattern can also be a declaration in some cases + * + * TODO, patterns don't parse with regions, so we lose track of precise + * position information if we want to point to an inner portion of it + */ def toPattern(d: NonBinding): Option[Pattern.Parsed] = d match { case Annotation(term, tpe) => toPattern(term).map(Pattern.Annotation(_, tpe)) - case Var(nm@Identifier.Constructor(_)) => - Some(Pattern.PositionalStruct( - Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), Nil)) + case Var(nm @ Identifier.Constructor(_)) => + Some( + Pattern.PositionalStruct( + Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), + Nil + ) + ) case Var(v: Bindable) => Some(Pattern.Var(v)) - case Literal(lit) => Some(Pattern.Literal(lit)) + case Literal(lit) => Some(Pattern.Literal(lit)) case StringDecl(NonEmptyList(Right((_, s)), Nil)) => Some(Pattern.Literal(Lit.Str(s))) case StringDecl(items) => - def toStrPart(p: Either[NonBinding, (Region, String)]): Option[Pattern.StrPart] = + def toStrPart( + p: Either[NonBinding, (Region, String)] + ): Option[Pattern.StrPart] = p match { - case Right((_, str)) => Some(Pattern.StrPart.LitStr(str)) + case Right((_, str)) => Some(Pattern.StrPart.LitStr(str)) case Left(Var(v: Bindable)) => Some(Pattern.StrPart.NamedStr(v)) - case _ => None + case _ => None } items.traverse(toStrPart).map(Pattern.StrPat(_)) case ListDecl(ListLang.Cons(elems)) => @@ -813,10 +964,12 @@ object Declaration { (toPattern(left), toPattern(right)).mapN { (l, r) => Pattern.union(l, r :: Nil) } - case Apply(Var(nm@Identifier.Constructor(_)), args, ApplyKind.Parens) => + case Apply(Var(nm @ Identifier.Constructor(_)), args, ApplyKind.Parens) => args.traverse(toPattern(_)).map { argPats => - Pattern.PositionalStruct(Pattern.StructKind.Named(nm, - Pattern.StructKind.Style.TupleLike), argPats.toList) + Pattern.PositionalStruct( + Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), + argPats.toList + ) } case TupleCons(ps) => ps.traverse(toPattern(_)).map { argPats => @@ -824,14 +977,15 @@ object Declaration { } case Parens(p: NonBinding) => toPattern(p) case RecordConstructor(cons, args) => - args.traverse { - case RecordArg.Simple(b) => Some(Left(b)) - case RecordArg.Pair(k, v) => - toPattern(v).map { vpat => - Right((k, vpat)) - } - } - .map(Pattern.recordPat(cons, _)(Pattern.StructKind.Named(_, _))) + args + .traverse { + case RecordArg.Simple(b) => Some(Left(b)) + case RecordArg.Pair(k, v) => + toPattern(v).map { vpat => + Right((k, vpat)) + } + } + .map(Pattern.recordPat(cons, _)(Pattern.StructKind.Named(_, _))) case _ => None } @@ -845,23 +999,28 @@ object Declaration { parser.indentBefore.mapF(Padding.parser(_)) def commentP(parser: Indy[Declaration]): Parser.Indy[Declaration] = - CommentStatement.parser( + CommentStatement + .parser( { indent => Padding.parser(P.string0(indent).with1 *> parser(indent)) } ) .region - .map { - case (r, c) => - c.on.padded match { - case nb: NonBinding => - CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))(r) - case _ => - Comment(c)(r) - } + .map { case (r, c) => + c.on.padded match { + case nb: NonBinding => + CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))(r) + case _ => + Comment(c)(r) + } } def commentNBP(parser: P[NonBinding]): Indy[CommentNB] = - CommentStatement.parser( - { indent => Padding.parser(P.string0(indent).with1 *> (Parser.maybeSpace.soft.with1 *> parser)) } + CommentStatement + .parser( + { indent => + Padding.parser( + P.string0(indent).with1 *> (Parser.maybeSpace.soft.with1 *> parser) + ) + } ) .region .map { case (r, c) => CommentNB(c)(r) } @@ -871,7 +1030,8 @@ object Declaration { OptIndent.indy(parser).product(Indy.lift(toEOL1) *> restP(parser)) restParser.mapF { rp => - DefStatement.parser(Pattern.bindParser, maybeSpace.with1 *> rp) + DefStatement + .parser(Pattern.bindParser, maybeSpace.with1 *> rp) .region .map { case (r, d) => DefFn(d)(r) } } @@ -892,7 +1052,9 @@ object Declaration { .map(_._2) val elifs1 = - ifelif("elif").nonEmptyList(sepIndy = Indy.toEOLIndent) <* Indy.toEOLIndent + ifelif("elif").nonEmptyList(sepIndy = + Indy.toEOLIndent + ) <* Indy.toEOLIndent val notIfs = Indy { indent => elifs1(indent).?.with1 ~ elseTerm(indent) @@ -901,14 +1063,14 @@ object Declaration { (ifelif("if") <* Indy.toEOLIndent) .cutThen(notIfs) .region - .map { - case (region, (ifcase, (optElses, elseBody))) => - val elses = - optElses match { - case None => Nil - case Some(s) => s.toList // type inference works better than fold sadly - } - IfElse(NonEmptyList(ifcase, elses), elseBody)(region) + .map { case (region, (ifcase, (optElses, elseBody))) => + val elses = + optElses match { + case None => Nil + case Some(s) => + s.toList // type inference works better than fold sadly + } + IfElse(NonEmptyList(ifcase, elses), elseBody)(region) } } @@ -919,7 +1081,8 @@ object Declaration { val q2 = '"' inner.mapF { p => - val plist = StringUtil.interpolatedString(q1, start, p, end) + val plist = StringUtil + .interpolatedString(q1, start, p, end) .orElse(StringUtil.interpolatedString(q2, start, p, end)) plist.region.map { @@ -930,22 +1093,27 @@ object Declaration { Literal(Lit.Str(str))(r) case (r, h :: tail) => StringDecl(NonEmptyList(h, tail))(r) - } + } } } def lambdaP(parser: Indy[Declaration]): Indy[Lambda] = { - val params = Indy.lift(P.char('\\') *> maybeSpace *> Pattern.bindParser.nonEmptyList) + val params = + Indy.lift(P.char('\\') *> maybeSpace *> Pattern.bindParser.nonEmptyList) - val withSlash = OptIndent.blockLike(params, parser, maybeSpace.with1 *> rightArrow) + val withSlash = OptIndent + .blockLike(params, parser, maybeSpace.with1 *> rightArrow) .region .map { case (r, (args, body)) => Lambda(args, body.get)(r) } val noSlashParamsArrow = // patterns are ambiguous with expressions wo se need backtracking - MaybeTupleOrParens.parser(Pattern.bindParser) <* (maybeSpace *> ((!Operators.operatorToken) *> rightArrow)) - - val noSlash = OptIndent.blockLike(Indy.lift(noSlashParamsArrow.backtrack), parser, P.unit) + MaybeTupleOrParens.parser( + Pattern.bindParser + ) <* (maybeSpace *> ((!Operators.operatorToken) *> rightArrow)) + + val noSlash = OptIndent + .blockLike(Indy.lift(noSlashParamsArrow.backtrack), parser, P.unit) .region .map { case (r, (rawPat, body)) => val args = rawPat match { @@ -955,7 +1123,7 @@ object Declaration { NonEmptyList.one(p) case MaybeTupleOrParens.Tuple(Nil) => // consider this the same as the pattern () - NonEmptyList.one(Pattern.tuple(Nil)) + NonEmptyList.one(Pattern.tuple(Nil)) case MaybeTupleOrParens.Tuple(h :: tail) => // we consider a top level non-empty tuple to be a list: NonEmptyList(h, tail) @@ -970,52 +1138,70 @@ object Declaration { val withTrailingExpr = expr.cutLeftP(maybeSpace) // TODO: make this strict val bp = (P.string("case") *> Parser.spaces).?.with1 *> Pattern.matchParser - //val bp = (P.string("case") *> Parser.spaces).with1 *> Pattern.matchParser + // val bp = (P.string("case") *> Parser.spaces).with1 *> Pattern.matchParser val branch = OptIndent.block(Indy.lift(bp), withTrailingExpr) - val left = Indy.lift(matchKindParser <* spaces).cutThen(arg).cutLeftP(maybeSpace) - OptIndent.block(left, branch.nonEmptyList(Indy.toEOLIndent)) + val left = + Indy.lift(matchKindParser <* spaces).cutThen(arg).cutLeftP(maybeSpace) + OptIndent + .block(left, branch.nonEmptyList(Indy.toEOLIndent)) .region .map { case (r, ((kind, arg), branches)) => Match(kind, arg, branches)(r) } } - /** - * These are keywords inside declarations (if, match, def) - * that cannot be used by identifiers - */ + /** These are keywords inside declarations (if, match, def) that cannot be + * used by identifiers + */ val keywords: Set[String] = - Set("from", "import", "if", "else", "elif", "match", "matches", "def", "recur", "struct", "enum") - - /** - * A Parser that matches keywords - */ + Set( + "from", + "import", + "if", + "else", + "elif", + "match", + "matches", + "def", + "recur", + "struct", + "enum" + ) + + /** A Parser that matches keywords + */ val keywordsP: P[Unit] = P.oneOf(keywords.toList.sorted.map(P.string(_))) <* spaces val varP: P[Var] = - (!keywordsP).with1 *> Identifier.bindableParser.region.map { case (r, i) => Var(i)(r) } + (!keywordsP).with1 *> Identifier.bindableParser.region.map { case (r, i) => + Var(i)(r) + } // this returns a Var with a Constructor or a RecordConstrutor - def recordConstructorP(indent: String, declP: P[NonBinding], noAnn: P[NonBinding]): P[NonBinding] = { + def recordConstructorP( + indent: String, + declP: P[NonBinding], + noAnn: P[NonBinding] + ): P[NonBinding] = { val ws = Parser.maybeIndentedOrSpace(indent) val kv: P[RecordArg] = RecordArg.parser(indent, noAnn) val kvs = kv.nonEmptyListOfWs(ws) // here is the record style: Foo {x: 1, ... - val recArgs = kvs.bracketed(maybeSpace.with1.soft ~ P.char('{') ~ ws, ws ~ P.char('}')) + val recArgs = + kvs.bracketed(maybeSpace.with1.soft ~ P.char('{') ~ ws, ws ~ P.char('}')) // here is tuple style: Foo(a, b) - val tupArgs = declP - .parensLines1Cut - .region - .map { case (r, args) => - { (nm: Var) => Apply(nm, args, ApplyKind.Parens)(nm.region + r) } + val tupArgs = declP.parensLines1Cut.region + .map { + case (r, args) => { (nm: Var) => + Apply(nm, args, ApplyKind.Parens)(nm.region + r) + } } - (Identifier.consParser ~ Parser.either(recArgs, tupArgs).?) - .region + (Identifier.consParser ~ Parser.either(recArgs, tupArgs).?).region .map { case (region, (n, Some(Left(args)))) => RecordConstructor(n, args)(region) @@ -1031,18 +1217,23 @@ object Declaration { case object Equals extends PatternBindKind case object LeftApplyFn extends PatternBindKind - + val parser: P[PatternBindKind] = eqP.as(Equals) | leftApplyFnP.as(LeftApplyFn) } - private def patternBind(nonBindingParser: Indy[NonBinding], decl: Indy[Declaration]): Indy[Declaration] = { + private def patternBind( + nonBindingParser: Indy[NonBinding], + decl: Indy[Declaration] + ): Indy[Declaration] = { val pat = MaybeTupleOrParens.parser(Pattern.bindParser) - val patPart = pat.region ~ (maybeSpace *> PatternBindKind.parser <* maybeSpace) + val patPart = + pat.region ~ (maybeSpace *> PatternBindKind.parser <* maybeSpace) val parser = nonBindingParser <* Indy.lift(toEOL1) // we can't cut the pattern here because we have some ambiguity in declarations // allow = to be like a block, we can continue on the next line indented - OptIndent.blockLike(Indy.lift(patPart.backtrack), parser, P.unit) + OptIndent + .blockLike(Indy.lift(patPart.backtrack), parser, P.unit) .cutThen(restP(decl)) .region .map { case (region, ((((preg, rawPat), pbk), value), decl)) => @@ -1054,22 +1245,26 @@ object Declaration { case PatternBindKind.LeftApplyFn => val pat = Pattern.fromMaybeTupleOrParens(rawPat) LeftApply(pat, preg, value.get, decl) - + } } } private def listP(p: P[NonBinding], src: P[NonBinding]): P[ListDecl] = - ListLang.parser(p, src, Pattern.bindParser) + ListLang + .parser(p, src, Pattern.bindParser) .region .map { case (r, l) => ListDecl(l)(r) } private def dictP(p: P[NonBinding], src: P[NonBinding]): P[DictDecl] = - ListLang.dictParser(p, src, Pattern.bindParser) + ListLang + .dictParser(p, src, Pattern.bindParser) .region .map { case (r, l) => DictDecl(l)(r) } - val lits: P[Literal] = Lit.integerParser.region.map { case (r, l) => Literal(l)(r) } + val lits: P[Literal] = Lit.integerParser.region.map { case (r, l) => + Literal(l)(r) + } private sealed abstract class ParseMode private object ParseMode { @@ -1085,220 +1280,253 @@ object Declaration { * we also parse Bind, Def, Comment */ private[this] val parserCache: ((ParseMode, String)) => P[Declaration] = - Memoize.memoizeDagHashedConcurrent[(ParseMode, String), P[Declaration]] { case ((pm, indent), rec) => - - // TODO: - // since we do a hard set of the mode in these, we lose the thread if we are inside a - // BranchArg so the trailing values : should be interpretted as a the branch end. - // This may actually make the file ambiguous in some cases, or at least point - // to a strange place on parse errors. - // - // I think we need to separate block-like expressions using : from NonBinding - // and make sure that we don't have block like expressions in certain places - - val recurseDecl: P[Declaration] = P.defer(rec((ParseMode.Decl, indent))) // needs to be inside a P for laziness - val recIndy: Indy[Declaration] = Indy { i => rec((ParseMode.Decl, i)) } - - // TODO: aren't NonBinding independent of indentation level> - val recNonBind: P[NonBinding] = P.defer(rec((ParseMode.NB, indent))).asInstanceOf[P[NonBinding]] - val recNBIndy: Indy[NonBinding] = Indy { i => rec((ParseMode.NB, i)).asInstanceOf[P[NonBinding]] } - - val recArg: P[NonBinding] = P.defer(rec((ParseMode.BranchArg, indent)).asInstanceOf[P[NonBinding]]) - val recArgIndy: Indy[NonBinding] = Indy { i => rec((ParseMode.BranchArg, i)).asInstanceOf[P[NonBinding]] } - - val recComp: P[NonBinding] = P.defer(rec((ParseMode.ComprehensionSource, indent))).asInstanceOf[P[NonBinding]] - - val nestedBlock: P[Region => Declaration.NonBinding] = { - /** - * we can either do: ( y = 1 - * y) - * starting a new declaration without indentation, - * or ( - * y = 1 - * y) - * where we allow more indentation. - */ - val noIndent = recurseDecl - val withIndent = Parser.newline *> Parser.spaces.string.flatMap { indent => recIndy(indent) } - maybeSpace.with1 *> (withIndent | noIndent).map { d => { (r: Region) => Parens(d)(r) } } <* maybeSpacesAndLines - } + Memoize.memoizeDagHashedConcurrent[(ParseMode, String), P[Declaration]] { + case ((pm, indent), rec) => + // TODO: + // since we do a hard set of the mode in these, we lose the thread if we are inside a + // BranchArg so the trailing values : should be interpretted as a the branch end. + // This may actually make the file ambiguous in some cases, or at least point + // to a strange place on parse errors. + // + // I think we need to separate block-like expressions using : from NonBinding + // and make sure that we don't have block like expressions in certain places + + val recurseDecl: P[Declaration] = P.defer( + rec((ParseMode.Decl, indent)) + ) // needs to be inside a P for laziness + val recIndy: Indy[Declaration] = Indy { i => rec((ParseMode.Decl, i)) } + + // TODO: aren't NonBinding independent of indentation level> + val recNonBind: P[NonBinding] = + P.defer(rec((ParseMode.NB, indent))).asInstanceOf[P[NonBinding]] + val recNBIndy: Indy[NonBinding] = Indy { i => + rec((ParseMode.NB, i)).asInstanceOf[P[NonBinding]] + } - val tupOrPar: P[NonBinding] = - // TODO: the backtrack here is bad... - Parser.parens(((maybeSpacesAndLines.with1.soft *> ((recNonBind <* (!(maybeSpace ~ bindOp))).backtrack <* maybeSpacesAndLines)) - .tupleOrParens0 - .map { - case Left(p) => { (r: Region) => Parens(p)(r) } - case Right(tup) => { (r: Region) => TupleCons(tup.toList)(r) } - }) - .orElse(nestedBlock) - // or it could be () which is just unit - .orElse(P.pure({ (r: Region) => TupleCons(Nil)(r) })) - , P.unit) - .region - .map { case (r, fn) => fn(r) } - - // since x -> y: t will parse like x -> (y: t) - // if we are in a branch arg, we can't parse annotations on the body of the lambda - val lambBody = if (pm == ParseMode.BranchArg) recArgIndy.asInstanceOf[Indy[Declaration]] else recIndy - val ternaryElseP = if (pm == ParseMode.BranchArg) recArg else recNonBind - - val allNonBind: P[NonBinding] = - P.defer( - P.oneOf( - lambdaP(lambBody)(indent) :: - ifElseP(recArgIndy, recIndy)(indent) :: - matchP(recArgIndy, recIndy)(indent) :: - dictP(recArg, recComp) :: - varP :: - listP(recNonBind, recComp) :: - lits :: - stringDeclOrLit(recNBIndy)(indent) :: - tupOrPar :: - recordConstructorP(indent, recNonBind, recArg) :: - // TODO: comment is ambiguous with binding/non-binding... - // so it prevents us commenting a binding statement - commentNBP(recNonBind)(indent) :: - Nil)) - - /* - * This is where we parse application, either direct, or dot-style - */ - val applied: P[NonBinding] = { - // here we are using . syntax foo.bar(1, 2) - // we also allow foo.(anyExpression)(1, 2) - val fn = varP.orElse(recNonBind.parensCut) - val slashcontinuation = ((maybeSpace ~ P.char('\\') ~ toEOL1).backtrack ~ Parser.maybeSpacesAndLines).?.void - // 0 or more args - val params0 = recNonBind.parensLines0Cut - val dotApply: P[NonBinding => NonBinding] = - (slashcontinuation.with1 *> P.char('.') *> (fn ~ params0)) - .region - .map { case (r2, (fn, args)) => + val recArg: P[NonBinding] = P.defer( + rec((ParseMode.BranchArg, indent)).asInstanceOf[P[NonBinding]] + ) + val recArgIndy: Indy[NonBinding] = Indy { i => + rec((ParseMode.BranchArg, i)).asInstanceOf[P[NonBinding]] + } - { (head: NonBinding) => Apply(fn, NonEmptyList(head, args), ApplyKind.Dot)(head.region + r2) } - } + val recComp: P[NonBinding] = P + .defer(rec((ParseMode.ComprehensionSource, indent))) + .asInstanceOf[P[NonBinding]] + + val nestedBlock: P[Region => Declaration.NonBinding] = { + + /** we can either do: ( y = 1 y) starting a new declaration without + * indentation, or ( y = 1 y) where we allow more indentation. + */ + val noIndent = recurseDecl + val withIndent = Parser.newline *> Parser.spaces.string.flatMap { + indent => recIndy(indent) + } + maybeSpace.with1 *> (withIndent | noIndent).map { d => + { (r: Region) => Parens(d)(r) } + } <* maybeSpacesAndLines + } - // 1 or more args - val params1 = recNonBind.parensLines1Cut - // here we directly call a function foo(1, 2) - val applySuffix: P[NonBinding => NonBinding] = - params1 + val tupOrPar: P[NonBinding] = + // TODO: the backtrack here is bad... + Parser + .parens( + ((maybeSpacesAndLines.with1.soft *> ((recNonBind <* (!(maybeSpace ~ bindOp))).backtrack <* maybeSpacesAndLines)).tupleOrParens0 + .map { + case Left(p) => { (r: Region) => Parens(p)(r) } + case Right(tup) => { (r: Region) => TupleCons(tup.toList)(r) } + }) + .orElse(nestedBlock) + // or it could be () which is just unit + .orElse(P.pure({ (r: Region) => TupleCons(Nil)(r) })), + P.unit + ) .region - .map { case (r, args) => - { (fn: NonBinding) => Apply(fn, args, ApplyKind.Parens)(fn.region + r) } + .map { case (r, fn) => fn(r) } + + // since x -> y: t will parse like x -> (y: t) + // if we are in a branch arg, we can't parse annotations on the body of the lambda + val lambBody = + if (pm == ParseMode.BranchArg) + recArgIndy.asInstanceOf[Indy[Declaration]] + else recIndy + val ternaryElseP = if (pm == ParseMode.BranchArg) recArg else recNonBind + + val allNonBind: P[NonBinding] = + P.defer( + P.oneOf( + lambdaP(lambBody)(indent) :: + ifElseP(recArgIndy, recIndy)(indent) :: + matchP(recArgIndy, recIndy)(indent) :: + dictP(recArg, recComp) :: + varP :: + listP(recNonBind, recComp) :: + lits :: + stringDeclOrLit(recNBIndy)(indent) :: + tupOrPar :: + recordConstructorP(indent, recNonBind, recArg) :: + // TODO: comment is ambiguous with binding/non-binding... + // so it prevents us commenting a binding statement + commentNBP(recNonBind)(indent) :: + Nil + ) + ) + + /* + * This is where we parse application, either direct, or dot-style + */ + val applied: P[NonBinding] = { + // here we are using . syntax foo.bar(1, 2) + // we also allow foo.(anyExpression)(1, 2) + val fn = varP.orElse(recNonBind.parensCut) + val slashcontinuation = ((maybeSpace ~ P.char( + '\\' + ) ~ toEOL1).backtrack ~ Parser.maybeSpacesAndLines).?.void + // 0 or more args + val params0 = recNonBind.parensLines0Cut + val dotApply: P[NonBinding => NonBinding] = + (slashcontinuation.with1 *> P.char('.') *> (fn ~ params0)).region + .map { + case (r2, (fn, args)) => { (head: NonBinding) => + Apply(fn, NonEmptyList(head, args), ApplyKind.Dot)( + head.region + r2 + ) + } + } + + // 1 or more args + val params1 = recNonBind.parensLines1Cut + // here we directly call a function foo(1, 2) + val applySuffix: P[NonBinding => NonBinding] = + params1.region + .map { + case (r, args) => { (fn: NonBinding) => + Apply(fn, args, ApplyKind.Parens)(fn.region + r) + } + } + + def repFn[A](fn: P[A => A]): P0[A => A] = + fn.rep0.map { opList => + { (a: A) => opList.foldLeft(a) { (arg, fn) => fn(arg) } } } - def repFn[A](fn: P[A => A]): P0[A => A] = - fn.rep0.map { opList => - { (a: A) => opList.foldLeft(a) { (arg, fn) => fn(arg) } } + (allNonBind ~ repFn(dotApply.orElse(applySuffix))) + .map { case (a, f) => f(a) } + } + // lower priority than calls is type annotation + val annotated: P[NonBinding] = + if (pm == ParseMode.BranchArg) applied + else { + val an: P[NonBinding => NonBinding] = + TypeRef.annotationParser + // TODO remove this backtrack, + // currently we can confuse ending a block with type annotation + // without backtracking here due to nesting losing track of + // when a trailing item is in a BranchArg in e.g. match or if bodies + .backtrack.region + .map { + case (r, tpe) => { (nb: NonBinding) => + Annotation(nb, tpe)(nb.region + r) + } + } + + applied.maybeAp(an) } - (allNonBind ~ repFn(dotApply.orElse(applySuffix))) - .map { case (a, f) => f(a) } - } - // lower priority than calls is type annotation - val annotated: P[NonBinding] = - if (pm == ParseMode.BranchArg) applied - else { - val an: P[NonBinding => NonBinding] = - TypeRef.annotationParser - // TODO remove this backtrack, - // currently we can confuse ending a block with type annotation - // without backtracking here due to nesting losing track of - // when a trailing item is in a BranchArg in e.g. match or if bodies - .backtrack - .region - .map { case (r, tpe) => - { (nb: NonBinding) => Annotation(nb, tpe)(nb.region + r) } + // matched + val matched: P[NonBinding] = { + // x matches p + val matchesOp = + ((maybeSpace.with1 *> P.string( + "matches" + ) *> spaces).backtrack *> Pattern.matchParser).region + .map { + case (region, pat) => { (nb: NonBinding) => + Matches(nb, pat)(nb.region + region) + } } + .rep + .map { fns => fns.toList.reduceLeft(_.andThen(_)) } - applied.maybeAp(an) + annotated.maybeAp(matchesOp) } - // matched - val matched: P[NonBinding] = { - // x matches p - val matchesOp = - ((maybeSpace.with1 *> P.string("matches") *> spaces).backtrack *> Pattern.matchParser) - .region - .map { case (region, pat) => - - { (nb: NonBinding) => Matches(nb, pat)(nb.region + region) } + // Applying is higher precedence than any operators + // now parse an operator apply + def postOperators(nb: P[NonBinding]): P[NonBinding] = { + + def convert(form: Operators.Formula[NonBinding]): NonBinding = + form match { + case Operators.Formula.Sym(r) => r + case Operators.Formula.Op(left, op, right) => + val leftD = convert(left) + val rightD = convert(right) + // `op`(l, r) + ApplyOp(leftD, Identifier.Operator(op), rightD) } - .rep - .map { fns => fns.toList.reduceLeft(_.andThen(_)) } - annotated.maybeAp(matchesOp) - } + // one or more operators + val ops: P[NonBinding => Operators.Formula[NonBinding]] = + maybeSpace.with1.soft *> ((!bindOp).with1 *> Operators.Formula + .infixOps1(nb)) - // Applying is higher precedence than any operators - // now parse an operator apply - def postOperators(nb: P[NonBinding]): P[NonBinding] = { - - def convert(form: Operators.Formula[NonBinding]): NonBinding = - form match { - case Operators.Formula.Sym(r) => r - case Operators.Formula.Op(left, op, right) => - val leftD = convert(left) - val rightD = convert(right) - // `op`(l, r) - ApplyOp(leftD, Identifier.Operator(op), rightD) + // This already parses as many as it can, so we don't need repFn + val form = ops.map { fn => + { (d: NonBinding) => convert(fn(d)) } } - // one or more operators - val ops: P[NonBinding => Operators.Formula[NonBinding]] = - maybeSpace.with1.soft *> ((!bindOp).with1 *> Operators.Formula.infixOps1(nb)) - - // This already parses as many as it can, so we don't need repFn - val form = ops.map { fn => - - { (d: NonBinding) => convert(fn(d)) } + nb.maybeAp(form) } - nb.maybeAp(form) - } + // here is if/ternary operator + // it fully recurses on the else branch, which will parse any repeated ternaryies + // so no need to repeat here for correct precedence + val ternary: P[NonBinding => NonBinding] = + (((spaces *> P.string( + "if" + ) *> spaces).backtrack *> recNonBind) ~ (spaces *> keySpace( + "else" + ) *> ternaryElseP)) + .map { + case (cond, falseCase) => { (trueCase: NonBinding) => + Ternary(trueCase, cond, falseCase) + } + } - // here is if/ternary operator - // it fully recurses on the else branch, which will parse any repeated ternaryies - // so no need to repeat here for correct precedence - val ternary: P[NonBinding => NonBinding] = - (((spaces *> P.string("if") *> spaces).backtrack *> recNonBind) ~ (spaces *> keySpace("else") *> ternaryElseP)) - .map { case (cond, falseCase) => - { (trueCase: NonBinding) => Ternary(trueCase, cond, falseCase) } - } + val finalNonBind: P[NonBinding] = + if (pm != ParseMode.ComprehensionSource) + postOperators(matched).maybeAp(ternary) + else postOperators(matched) - val finalNonBind: P[NonBinding] = - if (pm != ParseMode.ComprehensionSource) postOperators(matched).maybeAp(ternary) - else postOperators(matched) - - if (pm != ParseMode.Decl) finalNonBind - else { - val finalBind: P[Declaration] = P.defer( - P.oneOf( - // these have keywords which need to be parsed before var (def, match, if) - defP(recIndy)(indent) :: - // these are not ambiguous with patterns - commentP(recIndy)(indent) :: - /* - * challenge is that not all Declarations are Patterns, and not - * all Patterns are Declarations. So, bindings, which are: pattern = declaration - * is a bit hard. This also makes cuts a bit dangerous, since this ambiguity - * between pattern and declaration means if we use cuts too aggressively, we - * will fail. - * - * If we parse a declaration first, if we see = we need to convert - * to pattern. If we parse a pattern, but it was actually a declaration, we need - * to convert there. This code tries to parse as a declaration first, then converts - * it to pattern if we see an = - */ - patternBind(recNBIndy, recIndy)(indent) :: - Nil) + if (pm != ParseMode.Decl) finalNonBind + else { + val finalBind: P[Declaration] = P.defer( + P.oneOf( + // these have keywords which need to be parsed before var (def, match, if) + defP(recIndy)(indent) :: + // these are not ambiguous with patterns + commentP(recIndy)(indent) :: + /* + * challenge is that not all Declarations are Patterns, and not + * all Patterns are Declarations. So, bindings, which are: pattern = declaration + * is a bit hard. This also makes cuts a bit dangerous, since this ambiguity + * between pattern and declaration means if we use cuts too aggressively, we + * will fail. + * + * If we parse a declaration first, if we see = we need to convert + * to pattern. If we parse a pattern, but it was actually a declaration, we need + * to convert there. This code tries to parse as a declaration first, then converts + * it to pattern if we see an = + */ + patternBind(recNBIndy, recIndy)(indent) :: + Nil + ) ) - // we have to parse non-binds last - finalBind.orElse(finalNonBind) - } + // we have to parse non-binds last + finalBind.orElse(finalNonBind) + } } val parser: Indy[Declaration] = @@ -1306,7 +1534,9 @@ object Declaration { val nonBindingParser: Indy[NonBinding] = Indy { i => parserCache((ParseMode.NB, i)) }.asInstanceOf[Indy[NonBinding]] val nonBindingParserNoTern: Indy[NonBinding] = - Indy { i => parserCache((ParseMode.ComprehensionSource, i)) }.asInstanceOf[Indy[NonBinding]] + Indy { i => parserCache((ParseMode.ComprehensionSource, i)) } + .asInstanceOf[Indy[NonBinding]] val nonBindingParserNoAnn: Indy[NonBinding] = - Indy { i => parserCache((ParseMode.BranchArg, i)) }.asInstanceOf[Indy[NonBinding]] + Indy { i => parserCache((ParseMode.BranchArg, i)) } + .asInstanceOf[Indy[NonBinding]] } diff --git a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala index e27f0a639..2c91112a5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefRecursionCheck.scala @@ -7,20 +7,19 @@ import cats.implicits._ import Identifier.Bindable -/** - * Recursion in bosatsu is only allowed on a substructural match - * of one of the parameters to the def. This strict rule, along - * with strictly finite data, ensures that all recursion terminates - * - * The rules are as follows: - * 0. defs may not be shadowed. This makes checking for legal recursion easier - * 1. until we reach a recur match, we cannot access an outer def name. We want to avoid aliasing - * 2. a recur match must occur on one of the literal parameters to the def, and there can - * be only one recur match - * 3. inside each branch of the recur match, we may only recur on substructures in the match - * position. - * 4. if there is a recur match, there must be at least one real recursion - */ +/** Recursion in bosatsu is only allowed on a substructural match of one of the + * parameters to the def. This strict rule, along with strictly finite data, + * ensures that all recursion terminates + * + * The rules are as follows: 0. defs may not be shadowed. This makes checking + * for legal recursion easier + * 1. until we reach a recur match, we cannot access an outer def name. We + * want to avoid aliasing 2. a recur match must occur on one of the + * literal parameters to the def, and there can be only one recur match 3. + * inside each branch of the recur match, we may only recur on + * substructures in the match position. 4. if there is a recur match, + * there must be at least one real recursion + */ object DefRecursionCheck { type Res = ValidatedNel[RecursionError, Unit] @@ -29,30 +28,40 @@ object DefRecursionCheck { def region: Region def message: String } - case class InvalidRecursion(name: Bindable, illegalPosition: Region) extends RecursionError { + case class InvalidRecursion(name: Bindable, illegalPosition: Region) + extends RecursionError { def region = illegalPosition def message = s"invalid recursion on ${name.sourceCodeRepr}" } - case class IllegalShadow(fnname: Bindable, decl: Declaration) extends RecursionError { + case class IllegalShadow(fnname: Bindable, decl: Declaration) + extends RecursionError { def region = decl.region - def message = s"illegal shadowing on: ${fnname.sourceCodeRepr}. Recursive shadowing of def names disallowed" + def message = + s"illegal shadowing on: ${fnname.sourceCodeRepr}. Recursive shadowing of def names disallowed" } case class UnexpectedRecur(decl: Declaration.Match) extends RecursionError { def region = decl.region def message = "unexpected recur: may only appear unnested inside a def" } - case class RecurNotOnArg(decl: Declaration.Match, - fnname: Bindable, - args: NonEmptyList[NonEmptyList[Pattern.Parsed]]) extends RecursionError { + case class RecurNotOnArg( + decl: Declaration.Match, + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]] + ) extends RecursionError { def region = decl.region def message = { val argsDoc = - Doc.intercalate(Doc.empty, + Doc.intercalate( + Doc.empty, args.toList.map { group => (Doc.char('(') + - Doc.intercalate(Doc.comma + Doc.line, - group.toList.map { pat => Pattern.document[TypeRef].document(pat) }) + + Doc.intercalate( + Doc.comma + Doc.line, + group.toList.map { pat => + Pattern.document[TypeRef].document(pat) + } + ) + Doc.char(')')).grouped } ) @@ -60,25 +69,35 @@ object DefRecursionCheck { s"recur not on an argument to the def of ${fnname.sourceCodeRepr}, args: $argStr" } } - case class RecursionArgNotVar(fnname: Bindable, invalidArg: Declaration) extends RecursionError { + case class RecursionArgNotVar(fnname: Bindable, invalidArg: Declaration) + extends RecursionError { def region = invalidArg.region - def message = s"recursion in ${fnname.sourceCodeRepr} is not on a name (expect a name which is exactly a arg to the def)" + def message = + s"recursion in ${fnname.sourceCodeRepr} is not on a name (expect a name which is exactly a arg to the def)" } - case class RecursionNotSubstructural(fnname: Bindable, recurPat: Pattern.Parsed, arg: Declaration.Var) extends RecursionError { + case class RecursionNotSubstructural( + fnname: Bindable, + recurPat: Pattern.Parsed, + arg: Declaration.Var + ) extends RecursionError { def region = arg.region def message = s"recursion in ${fnname.sourceCodeRepr} not substructual" } - case class RecursiveDefNoRecur(defstmt: DefStatement[Pattern.Parsed, Declaration], recur: Declaration.Match) extends RecursionError { + case class RecursiveDefNoRecur( + defstmt: DefStatement[Pattern.Parsed, Declaration], + recur: Declaration.Match + ) extends RecursionError { def region = recur.region - def message = s"recur but no recursive call to ${defstmt.name.sourceCodeRepr}" + def message = + s"recur but no recursive call to ${defstmt.name.sourceCodeRepr}" } - /** - * Check a statement that all inner declarations contain legal - * recursion, or none at all. Note, we don't check for cases that will be caught - * by typechecking: namely, when we have nonrecursive defs, their names are not - * in scope during typechecking, so illegal recursion there simply won't typecheck. - */ + /** Check a statement that all inner declarations contain legal recursion, or + * none at all. Note, we don't check for cases that will be caught by + * typechecking: namely, when we have nonrecursive defs, their names are not + * in scope during typechecking, so illegal recursion there simply won't + * typecheck. + */ def checkStatement(s: Statement): Res = { import Statement._ import Impl._ @@ -124,14 +143,17 @@ object DefRecursionCheck { (dn == n) || outer.defNamesContain(n) } - def inDef(fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]]): InDef = + def inDef( + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]] + ): InDef = InDef(this, fnname, args, Set.empty) } sealed abstract class InDefState extends State { final def inDef: InDef = this match { - case id @ InDef(_, _, _, _) => id - case InDefRecurred(ir, _, _, _, _) => ir.inDef + case id @ InDef(_, _, _, _) => id + case InDefRecurred(ir, _, _, _, _) => ir.inDef case InRecurBranch(InDefRecurred(ir, _, _, _, _), _, _) => ir.inDef } @@ -139,7 +161,12 @@ object DefRecursionCheck { } case object TopLevel extends State - case class InDef(outer: State, fnname: Bindable, args: NonEmptyList[NonEmptyList[Pattern.Parsed]], localScope: Set[Bindable]) extends InDefState { + case class InDef( + outer: State, + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], + localScope: Set[Bindable] + ) extends InDefState { def addLocal(b: Bindable): InDef = InDef(outer, fnname, args, localScope + b) @@ -149,30 +176,34 @@ object DefRecursionCheck { // This is eta-expansion of the function name as a lambda so we can check using the lambda rule def asLambda(region: Region): Declaration.Lambda = { - val allNames = Iterator.iterate(0)(_ + 1).map { idx => Identifier.Name(s"a$idx") }.filterNot(_ == fnname) - + val allNames = Iterator + .iterate(0)(_ + 1) + .map { idx => Identifier.Name(s"a$idx") } + .filterNot(_ == fnname) + val func = cats.Functor[NonEmptyList].compose[NonEmptyList] // we allocate the names first. There is only one name inside: fnname val argsB = func.map(args)(_ => allNames.next()) val argsV: NonEmptyList[NonEmptyList[Declaration.NonBinding]] = - func.map(argsB)( - n => Declaration.Var(n)(region) - ) + func.map(argsB)(n => Declaration.Var(n)(region)) val argsP: NonEmptyList[NonEmptyList[Pattern.Parsed]] = - func.map(argsB)( - n => Pattern.Var(n) - ) + func.map(argsB)(n => Pattern.Var(n)) - // fn == (x, y) -> z -> f(x, y)(z) - val body = argsV.toList.foldLeft(Declaration.Var(fnname)(region): Declaration.NonBinding) { (called, group) => - Declaration.Apply(called, group, Declaration.ApplyKind.Parens)(region) + // fn == (x, y) -> z -> f(x, y)(z) + val body = argsV.toList.foldLeft( + Declaration.Var(fnname)(region): Declaration.NonBinding + ) { (called, group) => + Declaration.Apply(called, group, Declaration.ApplyKind.Parens)(region) } - def lambdify(args: NonEmptyList[NonEmptyList[Pattern.Parsed]], body: Declaration): Declaration.Lambda = { - val body1 = args.tail match { - case Nil => body + def lambdify( + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], + body: Declaration + ): Declaration.Lambda = { + val body1 = args.tail match { + case Nil => body case h :: tail => lambdify(NonEmptyList(h, tail), body) } Declaration.Lambda(args.head, body1)(region) @@ -181,10 +212,20 @@ object DefRecursionCheck { lambdify(argsP, body) } } - case class InDefRecurred(inRec: InDef, group: Int, index: Int, recur: Declaration.Match, recCount: Int) extends InDefState { + case class InDefRecurred( + inRec: InDef, + group: Int, + index: Int, + recur: Declaration.Match, + recCount: Int + ) extends InDefState { def incRecCount: InDefRecurred = copy(recCount = recCount + 1) } - case class InRecurBranch(inRec: InDefRecurred, branch: Pattern.Parsed, allowedNames: Set[Bindable]) extends InDefState { + case class InRecurBranch( + inRec: InDefRecurred, + branch: Pattern.Parsed, + allowedNames: Set[Bindable] + ) extends InDefState { def incRecCount: InRecurBranch = copy(inRec = inRec.incRecCount) } @@ -192,10 +233,11 @@ object DefRecursionCheck { * What is the index into the list of def arguments where we are doing our recursion */ def getRecurIndex( - fnname: Bindable, - args: NonEmptyList[NonEmptyList[Pattern.Parsed]], - m: Declaration.Match, - locals: Set[Bindable]): ValidatedNel[RecursionError, (Int, Int)] = { + fnname: Bindable, + args: NonEmptyList[NonEmptyList[Pattern.Parsed]], + m: Declaration.Match, + locals: Set[Bindable] + ): ValidatedNel[RecursionError, (Int, Int)] = { import Declaration._ m.arg match { case Var(v) => @@ -209,7 +251,6 @@ object DefRecursionCheck { if item.topNames.contains(v) } yield (gidx, idx) - if (idxes.hasNext) Validated.valid(idxes.next()) else Validated.invalidNel(RecurNotOnArg(m, fnname, args)) } @@ -222,9 +263,14 @@ object DefRecursionCheck { * Check that decl is a strict substructure of pat. We do this by making sure decl is a Var * and that var is one of the strict substrutures of the pattern. */ - def allowedRecursion(fnname: Bindable, pat: Pattern.Parsed, names: Set[Bindable], decl: Declaration): Res = + def allowedRecursion( + fnname: Bindable, + pat: Pattern.Parsed, + names: Set[Bindable], + decl: Declaration + ): Res = decl match { - case v@Declaration.Var(nm: Bindable) => + case v @ Declaration.Var(nm: Bindable) => if (names.contains(nm)) unitValid else Validated.invalidNel(RecursionNotSubstructural(fnname, pat, v)) case _ => @@ -240,20 +286,25 @@ object DefRecursionCheck { * for the algorithm here, but also for human readers to see that recursion is total */ def checkForIllegalBinds[A]( - state: State, - bs: Iterable[Bindable], - decl: Declaration)(next: ValidatedNel[RecursionError, A]): ValidatedNel[RecursionError, A] = { - val outerSet = state.outerDefNames - if (outerSet.isEmpty) next - else { - NonEmptyList.fromList(bs.iterator.filter(outerSet).toList.sorted) match { - case Some(nel) => - Validated.invalid(nel.map(IllegalShadow(_, decl))) - case None => - next - } + state: State, + bs: Iterable[Bindable], + decl: Declaration + )( + next: ValidatedNel[RecursionError, A] + ): ValidatedNel[RecursionError, A] = { + val outerSet = state.outerDefNames + if (outerSet.isEmpty) next + else { + NonEmptyList.fromList( + bs.iterator.filter(outerSet).toList.sorted + ) match { + case Some(nel) => + Validated.invalid(nel.map(IllegalShadow(_, decl))) + case None => + next } } + } /* * Unfortunately we lose the Applicative structure inside Declaration checking. @@ -276,19 +327,22 @@ object DefRecursionCheck { val unitSt: St[Unit] = pureSt(()) def checkForIllegalBindsSt[A]( - bs: Iterable[Bindable], - decl: Declaration): St[Unit] = - for { - state <- getSt - _ <- toSt(checkForIllegalBinds(state, bs, decl)(unitValid)) - _ <- (state match { - case id@InDef(_, _, _, _) => setSt(bs.foldLeft(id)(_.addLocal(_))) - case _ => unitSt - }) - } yield () - - private def argsOnDefName(fn: Declaration, - groups: NonEmptyList[NonEmptyList[Declaration]]): Option[(Bindable, NonEmptyList[NonEmptyList[Declaration]])] = + bs: Iterable[Bindable], + decl: Declaration + ): St[Unit] = + for { + state <- getSt + _ <- toSt(checkForIllegalBinds(state, bs, decl)(unitValid)) + _ <- (state match { + case id @ InDef(_, _, _, _) => setSt(bs.foldLeft(id)(_.addLocal(_))) + case _ => unitSt + }) + } yield () + + private def argsOnDefName( + fn: Declaration, + groups: NonEmptyList[NonEmptyList[Declaration]] + ): Option[(Bindable, NonEmptyList[NonEmptyList[Declaration]])] = fn match { case Declaration.Var(nm: Bindable) => Some((nm, groups)) case Declaration.Apply(fn1, args, _) => @@ -303,9 +357,11 @@ object DefRecursionCheck { .flatMapN { case (a, InRecurBranch(ir1, b1, _)) => setSt(InRecurBranch(ir1, b1, names)).as(a) - // $COVERAGE-OFF$ this should be unreachable + // $COVERAGE-OFF$ this should be unreachable case (_, unexpected) => - sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + sys.error( + s"invariant violation expected InRecurBranch: start = $start, end = $unexpected" + ) } case notRecur => sys.error(s"called setNames on $notRecur with names: $newNames") @@ -321,19 +377,24 @@ object DefRecursionCheck { setSt(InRecurBranch(ir1, b1, names)).as(a) // $COVERAGE-OFF$ this should be unreachable case (_, unexpected) => - sys.error(s"invariant violation expected InRecurBranch: start = $start, end = $unexpected") + sys.error( + s"invariant violation expected InRecurBranch: start = $start, end = $unexpected" + ) // $COVERAGE-ON$ this should be unreachable } case _ => in } - def checkApply(fn: Declaration, args: NonEmptyList[Declaration], region: Region): St[Unit] = + def checkApply( + fn: Declaration, + args: NonEmptyList[Declaration], + region: Region + ): St[Unit] = getSt.flatMap { case TopLevel => // without any recursion, normal typechecking will detect bad states: checkDecl(fn) *> args.traverse_(checkDecl) - case irb@InRecurBranch(inrec, branch, names) => - + case irb @ InRecurBranch(inrec, branch, names) => argsOnDefName(fn, NonEmptyList.one(args)) match { case Some((nm, groups)) => if (nm == irb.defname) { @@ -347,39 +408,38 @@ object DefRecursionCheck { toSt(allowedRecursion(irb.defname, branch, names, arg)) *> setSt(irb.incRecCount) // we have recurred again } - } - else if (irb.defNamesContain(nm)) { + } else if (irb.defNamesContain(nm)) { failSt(InvalidRecursion(nm, region)) - } - else if (names.contains(nm)) { + } else if (names.contains(nm)) { // we are calling a reachable function. Any lambda args are new names: args.traverse_[St, Unit] { case Declaration.Lambda(args, body) => val names1 = args.toList.flatMap(_.names) unionNames(names1)(checkDecl(body)) - case v@Declaration.Var(fn: Bindable) if irb.defname == fn => - val Declaration.Lambda(args, body) = irb.inDef.asLambda(v.region) + case v @ Declaration.Var(fn: Bindable) if irb.defname == fn => + val Declaration.Lambda(args, body) = + irb.inDef.asLambda(v.region) val names1 = args.toList.flatMap(_.names) unionNames(names1)(checkDecl(body)) case notLambda => checkDecl(notLambda) } - } - else { + } else { // traverse converting Var(name) to the lambda version to use the above check // not a recursive call args.traverse_(checkDecl) } case None => // this isn't a recursive call - checkDecl(fn) *> args.traverse_(checkDecl) + checkDecl(fn) *> args.traverse_(checkDecl) } case ir: InDefState => // we have either not yet, or already done the recursion argsOnDefName(fn, NonEmptyList.one(args)) match { - case Some((nm, _)) if ir.defNamesContain(nm) => failSt(InvalidRecursion(nm, region)) + case Some((nm, _)) if ir.defNamesContain(nm) => + failSt(InvalidRecursion(nm, region)) case _ => - checkDecl(fn) *> args.traverse_(checkDecl) - } + checkDecl(fn) *> args.traverse_(checkDecl) + } } /* * With the given state, check the given Declaration to see if @@ -390,13 +450,17 @@ object DefRecursionCheck { decl match { case Annotation(t, _) => checkDecl(t) case Apply(fn, args, _) => - checkApply(fn, args, decl.region) + checkApply(fn, args, decl.region) case ApplyOp(left, op, right) => - checkApply(Var(op)(decl.region), NonEmptyList(left, right :: Nil), decl.region) + checkApply( + Var(op)(decl.region), + NonEmptyList(left, right :: Nil), + decl.region + ) case Binding(BindingStatement(pat, thisDecl, next)) => checkForIllegalBindsSt(pat.names, decl) *> - checkDecl(thisDecl) *> - filterNames(pat.names)(checkDecl(next.padded)) + checkDecl(thisDecl) *> + filterNames(pat.names)(checkDecl(next.padded)) case Comment(cs) => checkDecl(cs.on.padded) case CommentNB(cs) => @@ -414,7 +478,7 @@ object DefRecursionCheck { } val e = checkDecl(elseCase.get) ifs *> e - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => checkDecl(la.rewrite) case Ternary(t, c, f) => checkDecl(t) *> checkDecl(c) *> checkDecl(f) @@ -433,10 +497,11 @@ object DefRecursionCheck { filterNames(pat.names)(checkDecl(next.get)) } argRes *> optRes - case recur@Match(RecursionKind.Recursive, _, cases) => + case recur @ Match(RecursionKind.Recursive, _, cases) => // this is a state change getSt.flatMap { - case TopLevel | InRecurBranch(_, _, _) | InDefRecurred(_, _, _, _, _) => + case TopLevel | InRecurBranch(_, _, _) | + InDefRecurred(_, _, _, _, _) => failSt(UnexpectedRecur(recur)) case InDef(_, defname, args, locals) => toSt(getRecurIndex(defname, args, recur, locals)).flatMap { idx => @@ -444,24 +509,24 @@ object DefRecursionCheck { // parent state def beginBranch(pat: Pattern.Parsed): St[Unit] = getSt.flatMap { - case ir@InDef(_, _, _, _) => + case ir @ InDef(_, _, _, _) => val rec = ir.setRecur(idx, recur) setSt(rec) *> beginBranch(pat) - case irr@InDefRecurred(_, _, _, _, _) => + case irr @ InDefRecurred(_, _, _, _, _) => setSt(InRecurBranch(irr, pat, pat.substructures.toSet)) case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable: $pat -> $illegal") - // $COVERAGE-ON$ - } + // $COVERAGE-ON$ + } val endBranch: St[Unit] = getSt.flatMap { case InRecurBranch(irr, _, _) => setSt(irr) - case illegal => + case illegal => // $COVERAGE-OFF$ this should be unreachable sys.error(s"unreachable end state: $illegal") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } cases.get.traverse_ { case (pat, next) => @@ -473,7 +538,7 @@ object DefRecursionCheck { } yield () } } - } + } case Matches(a, _) => // patterns don't use values checkDecl(a) @@ -490,7 +555,8 @@ object DefRecursionCheck { unitSt case ir: InDefState => // if this were an apply, it would have been handled by Apply(Var(... - if (ir.defNamesContain(v)) failSt(InvalidRecursion(v, decl.region)) + if (ir.defNamesContain(v)) + failSt(InvalidRecursion(v, decl.region)) else unitSt } case StringDecl(parts) => @@ -537,7 +603,10 @@ object DefRecursionCheck { * Binds are not allowed to be recursive, only defs, so here we just make sure * none of the free variables of the pattern are used in decl */ - def checkDef[A](state: State, defstmt: DefStatement[Pattern.Parsed, (OptIndent[Declaration], A)]): Res = { + def checkDef[A]( + state: State, + defstmt: DefStatement[Pattern.Parsed, (OptIndent[Declaration], A)] + ): Res = { val body = defstmt.result._1.get val nameArgs = defstmt.args.toList.flatMap(_.patternNames) val state1 = state.inDef(defstmt.name, defstmt.args) @@ -551,11 +620,18 @@ object DefRecursionCheck { unitSt case InDefRecurred(_, _, _, recur, 0) => // we hit a recur, but we didn't recurse - failSt[Unit](RecursiveDefNoRecur(defstmt.copy(result = defstmt.result._1.get), recur)) + failSt[Unit]( + RecursiveDefNoRecur( + defstmt.copy(result = defstmt.result._1.get), + recur + ) + ) case unreachable => // $COVERAGE-OFF$ this should be unreachable - sys.error(s"we would like to prove in the types we can't get here: $unreachable, $defstmt"): St[Unit] - // $COVERAGE-ON$ + sys.error( + s"we would like to prove in the types we can't get here: $unreachable, $defstmt" + ): St[Unit] + // $COVERAGE-ON$ }) // Note a def can't change the state // we either have a valid nested def, or we don't diff --git a/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala b/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala index 932814983..803869cff 100644 --- a/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/DefStatement.scala @@ -29,14 +29,16 @@ object DefStatement { import defs._ val res = retType.fold(Doc.empty) { t => arrow + t.toDoc } val taDoc = typeArgs match { - case None => Doc.empty - case Some(ta) => TypeRef.docTypeArgs(ta.toList) { - case None => Doc.empty - case Some(k) => colonSpace + Kind.toDoc(k) - } + case None => Doc.empty + case Some(ta) => + TypeRef.docTypeArgs(ta.toList) { + case None => Doc.empty + case Some(k) => colonSpace + Kind.toDoc(k) + } } val argDoc = - Doc.intercalate(Doc.empty, + Doc.intercalate( + Doc.empty, args.toList.map { args => Doc.char('(') + Doc.intercalate( @@ -67,7 +69,9 @@ object DefStatement { ( Parser.keySpace( "def" - ) *> (Identifier.bindableParser ~ TypeRef.typeParams(kindAnnot.?).? ~ args.rep) <* maybeSpace, + ) *> (Identifier.bindableParser ~ TypeRef + .typeParams(kindAnnot.?) + .? ~ args.rep) <* maybeSpace, result.with1 <* (maybeSpace.with1 ~ P.char(':')), resultTParser ) diff --git a/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala b/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala index c1a23ca22..f224c0232 100644 --- a/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala +++ b/core/src/main/scala/org/bykn/bosatsu/EditDistance.scala @@ -6,10 +6,10 @@ object EditDistance { def apply[A](a: Iterable[A], b: Iterable[A]): Int = a.foldLeft((0 to b.size).toList) { (prev, x) => (prev zip prev.tail zip b) - .scanLeft(prev.head + 1) { - case (h, ((d, v), y)) => min(min(h + 1, v + 1), d + (if (x == y) 0 else 1)) + .scanLeft(prev.head + 1) { case (h, ((d, v), y)) => + min(min(h + 1, v + 1), d + (if (x == y) 0 else 1)) } - }.last + }.last def string(a: String, b: String): Int = apply(a, b) diff --git a/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala b/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala index d24783b33..307361ef9 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Evaluation.scala @@ -9,9 +9,9 @@ import cats.implicits._ import Identifier.Bindable case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { - /** - * Holds the final value of the environment for each Package - */ + + /** Holds the final value of the environment for each Package + */ private[this] val envCache: MMap[PackageName, Map[Identifier, Eval[Value]]] = MMap.empty @@ -20,46 +20,46 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { externalNames.iterator.map { n => val tpe = p.program.types.getValue(p.name, n) match { case Some(t) => t - case None => + case None => // $COVERAGE-OFF$ // should never happen due to typechecking sys.error(s"from ${p.name} import unknown external def: $n") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } externals.toMap.get((p.name, n.asString)) match { case Some(ext) => (n, Eval.later(ext.call(tpe))) - case None => + case None => // $COVERAGE-OFF$ // should never happen due to typechecking sys.error(s"from ${p.name} no External for external def: $n") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - } - .toMap + }.toMap } private[this] lazy val gdr = pm.getDataRepr - private def evalLets(thisPack: PackageName, lets: List[(Bindable, RecursionKind, TypedExpr[T])]): List[(Bindable, Eval[Value])] = { + private def evalLets( + thisPack: PackageName, + lets: List[(Bindable, RecursionKind, TypedExpr[T])] + ): List[(Bindable, Eval[Value])] = { val exprs: List[(Bindable, Matchless.Expr)] = - rankn.RefSpace - .allocCounter + rankn.RefSpace.allocCounter .flatMap { c => lets - .traverse { - case (name, rec, te) => - Matchless.fromLet(name, rec, te, gdr, c) - .map((name, _)) + .traverse { case (name, rec, te) => + Matchless + .fromLet(name, rec, te, gdr, c) + .map((name, _)) } - } - .run - .value + } + .run + .value - val evalFn: (PackageName, Identifier) => Eval[Value] = - { (p, i) => - if (p == thisPack) Eval.defer(evaluate(p)(i)) - else evaluate(p)(i) - } + val evalFn: (PackageName, Identifier) => Eval[Value] = { (p, i) => + if (p == thisPack) Eval.defer(evaluate(p)(i)) + else evaluate(p)(i) + } type F[A] = List[(Bindable, A)] val ffunc = cats.Functor[List].compose(cats.Functor[(Bindable, *)]) @@ -67,10 +67,12 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { } private def evaluate(packName: PackageName): Map[Identifier, Eval[Value]] = - envCache.getOrElseUpdate(packName, { - val pack = pm.toMap(packName) - externalEnv(pack) ++ evalLets(packName, pack.program.lets) - }) + envCache.getOrElseUpdate( + packName, { + val pack = pm.toMap(packName) + externalEnv(pack) ++ evalLets(packName, pack.program.lets) + } + ) def evaluateLast(p: PackageName): Option[(Eval[Value], Type)] = for { @@ -80,18 +82,21 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { } yield (value, tpe.getType) // TODO: this only works for lets, not externals - def evaluateName(p: PackageName, name: Bindable): Option[(Eval[Value], Type)] = + def evaluateName( + p: PackageName, + name: Bindable + ): Option[(Eval[Value], Type)] = for { pack <- pm.toMap.get(p) - (_, _, tpe) <- pack.program.lets.filter { case (n, _, _) => n == name }.lastOption + (_, _, tpe) <- pack.program.lets.filter { case (n, _, _) => + n == name + }.lastOption value <- evaluate(p).get(name) } yield (value, tpe.getType) - /** - * Return the last test, if any, in the package. - * this is the test that is run when we test - * the package - */ + /** Return the last test, if any, in the package. this is the test that is run + * when we test the package + */ def lastTest(p: PackageName): Option[Eval[Value]] = for { pack <- pm.toMap.get(p) @@ -115,36 +120,30 @@ case class Evaluation[T](pm: PackageMap.Typed[T], externals: Externals) { Doc.intercalate(Doc.lineOrSpace, packs).render(80) } - */ + */ def evalTest(ps: PackageName): Option[Eval[Test]] = lastTest(ps).map { ea => ea.map(Test.fromValue(_)) } - /** - * Convert a typechecked value to Json - * this code ASSUMES the type is correct. If not, we may throw or return - * incorrect data. - */ - val valueToJson: ValueToJson = ValueToJson({ - case Type.Const.Defined(pn, t) => - for { - pack <- pm.toMap.get(pn) - dt <- pack.program.types.getType(pn, t) - } yield dt + /** Convert a typechecked value to Json this code ASSUMES the type is correct. + * If not, we may throw or return incorrect data. + */ + val valueToJson: ValueToJson = ValueToJson({ case Type.Const.Defined(pn, t) => + for { + pack <- pm.toMap.get(pn) + dt <- pack.program.types.getType(pn, t) + } yield dt }) - /** - * Convert a typechecked value to Doc - * this code ASSUMES the type is correct. If not, we may throw or return - * incorrect data. - */ - val valueToDoc: ValueToDoc = ValueToDoc({ - case Type.Const.Defined(pn, t) => - for { - pack <- pm.toMap.get(pn) - dt <- pack.program.types.getType(pn, t) - } yield dt + /** Convert a typechecked value to Doc this code ASSUMES the type is correct. + * If not, we may throw or return incorrect data. + */ + val valueToDoc: ValueToDoc = ValueToDoc({ case Type.Const.Defined(pn, t) => + for { + pack <- pm.toMap.get(pn) + dt <- pack.program.types.getType(pn, t) + } yield dt }) } diff --git a/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala b/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala index b8d855d1b..f5b94c67c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ExportedName.scala @@ -16,107 +16,120 @@ sealed abstract class ExportedName[+T] { self: Product => // we use them as hash keys final override val hashCode: Int = MurmurHash3.productHash(this) - /** - * Given name, in the current type environment and fully typed lets - * what does it correspond to? - */ + + /** Given name, in the current type environment and fully typed lets what does + * it correspond to? + */ private def toReferants[A]( - letValue: Option[rankn.Type], - definedType: Option[rankn.DefinedType[A]]): Option[NonEmptyList[ExportedName[Referant[A]]]] = - this match { - case ExportedName.Binding(n, _) => - letValue.map { tpe => - NonEmptyList.one(ExportedName.Binding(n, Referant.Value(tpe))) - } - case ExportedName.TypeName(nm, _) => - definedType.map { dt => - NonEmptyList.one(ExportedName.TypeName(nm, Referant.DefinedT(dt))) - } - case ExportedName.Constructor(nm, _) => - // export the type and all constructors - definedType.map { dt => - val cons = dt.constructors.map { cf => - ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) - } - val t = ExportedName.TypeName(nm, Referant.DefinedT(dt)) - NonEmptyList(t, cons) - } - } + letValue: Option[rankn.Type], + definedType: Option[rankn.DefinedType[A]] + ): Option[NonEmptyList[ExportedName[Referant[A]]]] = + this match { + case ExportedName.Binding(n, _) => + letValue.map { tpe => + NonEmptyList.one(ExportedName.Binding(n, Referant.Value(tpe))) + } + case ExportedName.TypeName(nm, _) => + definedType.map { dt => + NonEmptyList.one(ExportedName.TypeName(nm, Referant.DefinedT(dt))) + } + case ExportedName.Constructor(nm, _) => + // export the type and all constructors + definedType.map { dt => + val cons = dt.constructors.map { cf => + ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) + } + val t = ExportedName.TypeName(nm, Referant.DefinedT(dt)) + NonEmptyList(t, cons) + } + } } object ExportedName { - case class Binding[T](name: Identifier.Bindable, tag: T) extends ExportedName[T] - case class TypeName[T](name: Identifier.Constructor, tag: T) extends ExportedName[T] - case class Constructor[T](name: Identifier.Constructor, tag: T) extends ExportedName[T] + case class Binding[T](name: Identifier.Bindable, tag: T) + extends ExportedName[T] + case class TypeName[T](name: Identifier.Constructor, tag: T) + extends ExportedName[T] + case class Constructor[T](name: Identifier.Constructor, tag: T) + extends ExportedName[T] private[this] val consDoc = Doc.text("()") implicit val document: Document[ExportedName[Unit]] = { val di = Document[Identifier] Document.instance[ExportedName[Unit]] { - case Binding(n, _) => di.document(n) - case TypeName(n, _) => di.document(n) + case Binding(n, _) => di.document(n) + case TypeName(n, _) => di.document(n) case Constructor(n, _) => di.document(n) + consDoc } } val parser: P[ExportedName[Unit]] = - Identifier.bindableParser.map(Binding(_, ())) + Identifier.bindableParser + .map(Binding(_, ())) .orElse( (Identifier.consParser ~ P.string("()").?) .map { - case (n, None) => TypeName(n, ()) + case (n, None) => TypeName(n, ()) case (n, Some(_)) => Constructor(n, ()) } ) - /** - * Build exports into referants given a typeEnv - * The only error we have have here is if we name an export we didn't define - * Note a name can be two things: - * 1. a type - * 2. a value (e.g. a let or a constructor function) - */ + /** Build exports into referants given a typeEnv The only error we have have + * here is if we name an export we didn't define Note a name can be two + * things: + * 1. a type 2. a value (e.g. a let or a constructor function) + */ def buildExports[E, V, R, D]( - nm: PackageName, - exports: List[ExportedName[E]], - typeEnv: rankn.TypeEnv[V], - lets: List[(Identifier.Bindable, R, TypedExpr[D])])(implicit ev: V <:< Kind.Arg): ValidatedNel[ExportedName[E], List[ExportedName[Referant[V]]]] = { + nm: PackageName, + exports: List[ExportedName[E]], + typeEnv: rankn.TypeEnv[V], + lets: List[(Identifier.Bindable, R, TypedExpr[D])] + )(implicit + ev: V <:< Kind.Arg + ): ValidatedNel[ExportedName[E], List[ExportedName[Referant[V]]]] = { - val letMap = lets.iterator.map { case (n, _, t) => (n, t) }.toMap + val letMap = lets.iterator.map { case (n, _, t) => (n, t) }.toMap - def expName[A](ename: ExportedName[A]): Option[NonEmptyList[ExportedName[Referant[V]]]] = { - import ename.name - val letValue: Option[rankn.Type] = - name.toBindable - .flatMap { bn => - letMap.get(bn) - .map(_.getType) - .orElse { - // It could be an external or imported value in the TypeEnv - typeEnv.getValue(nm, bn) - } - } - val optDT = - name.toConstructor - .flatMap { cn => - typeEnv.getType(nm, org.bykn.bosatsu.TypeName(cn)) - } + def expName[A]( + ename: ExportedName[A] + ): Option[NonEmptyList[ExportedName[Referant[V]]]] = { + import ename.name + val letValue: Option[rankn.Type] = + name.toBindable + .flatMap { bn => + letMap + .get(bn) + .map(_.getType) + .orElse { + // It could be an external or imported value in the TypeEnv + typeEnv.getValue(nm, bn) + } + } + val optDT = + name.toConstructor + .flatMap { cn => + typeEnv.getType(nm, org.bykn.bosatsu.TypeName(cn)) + } - ename.toReferants(letValue, optDT) - } + ename.toReferants(letValue, optDT) + } - def expName1[A](ename: ExportedName[A]): ValidatedNel[ExportedName[A], List[ExportedName[Referant[V]]]] = - expName(ename) match { - case None => Validated.invalid(NonEmptyList.of(ename)) - case Some(v) => Validated.valid(v.toList) - } + def expName1[A]( + ename: ExportedName[A] + ): ValidatedNel[ExportedName[A], List[ExportedName[Referant[V]]]] = + expName(ename) match { + case None => Validated.invalid(NonEmptyList.of(ename)) + case Some(v) => Validated.valid(v.toList) + } - exports.traverse(expName1).map(_.flatten) + exports.traverse(expName1).map(_.flatten) } - def typeEnvFromExports[A](packageName: PackageName, exports: List[ExportedName[Referant[A]]]): TypeEnv[A] = + def typeEnvFromExports[A]( + packageName: PackageName, + exports: List[ExportedName[Referant[A]]] + ): TypeEnv[A] = exports.foldLeft((TypeEnv.empty): TypeEnv[A]) { (te, exp) => exp.tag.addTo(packageName, exp.name, te) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Expr.scala b/core/src/main/scala/org/bykn/bosatsu/Expr.scala index 8a2ac6dd6..94bda1d2f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Expr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Expr.scala @@ -1,9 +1,8 @@ package org.bykn.bosatsu -/** - * This is a scala port of the example of Hindley Milner inference - * here: http://dev.stephendiehl.com/fun/006_hindley_milner.html - */ +/** This is a scala port of the example of Hindley Milner inference here: + * http://dev.stephendiehl.com/fun/006_hindley_milner.html + */ import cats.implicits._ import cats.data.{Chain, Writer, NonEmptyList} @@ -22,16 +21,36 @@ object Expr { case class Annotation[T](expr: Expr[T], tpe: Type, tag: T) extends Expr[T] case class Local[T](name: Bindable, tag: T) extends Name[T] - case class Generic[T](typeVars: NonEmptyList[(Type.Var.Bound, Kind)], in: Expr[T]) extends Expr[T] { + case class Generic[T]( + typeVars: NonEmptyList[(Type.Var.Bound, Kind)], + in: Expr[T] + ) extends Expr[T] { def tag = in.tag } - case class Global[T](pack: PackageName, name: Identifier, tag: T) extends Name[T] - case class App[T](fn: Expr[T], args: NonEmptyList[Expr[T]], tag: T) extends Expr[T] - case class Lambda[T](args: NonEmptyList[(Bindable, Option[Type])], expr: Expr[T], tag: T) extends Expr[T] - case class Let[T](arg: Bindable, expr: Expr[T], in: Expr[T], recursive: RecursionKind, tag: T) extends Expr[T] + case class Global[T](pack: PackageName, name: Identifier, tag: T) + extends Name[T] + case class App[T](fn: Expr[T], args: NonEmptyList[Expr[T]], tag: T) + extends Expr[T] + case class Lambda[T]( + args: NonEmptyList[(Bindable, Option[Type])], + expr: Expr[T], + tag: T + ) extends Expr[T] + case class Let[T]( + arg: Bindable, + expr: Expr[T], + in: Expr[T], + recursive: RecursionKind, + tag: T + ) extends Expr[T] case class Literal[T](lit: Lit, tag: T) extends Expr[T] - case class Match[T](arg: Expr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr[T])], tag: T) extends Expr[T] - + case class Match[T]( + arg: Expr[T], + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], Expr[T]) + ], + tag: T + ) extends Expr[T] def forAll[A](tpeArgs: List[(Type.Var.Bound, Kind)], expr: Expr[A]): Expr[A] = NonEmptyList.fromList(tpeArgs) match { @@ -46,29 +65,31 @@ object Expr { case Generic(typeVars, in) => Generic(nel ::: typeVars, in) case notAnn => Generic(nel, notAnn) - } + } } def quantifyFrees[A](expr: Expr[A]): Expr[A] = forAll(freeBoundTyVars(expr).map((_, Kind.Type)), expr) - /** - * Report all the Bindable names refered to in the given Expr. - * this can be used to allocate names that can never shadow - * anything being used in the expr - */ + /** Report all the Bindable names refered to in the given Expr. this can be + * used to allocate names that can never shadow anything being used in the + * expr + */ final def allNames[A](expr: Expr[A]): SortedSet[Bindable] = expr match { case Annotation(e, _, _) => allNames(e) - case Local(name, _) => SortedSet(name) - case Generic(_, in) => allNames(in) - case Global(_, _, _) => SortedSet.empty - case App(fn, args, _) => args.foldLeft(allNames(fn))((bs, e) => bs | allNames(e)) + case Local(name, _) => SortedSet(name) + case Generic(_, in) => allNames(in) + case Global(_, _, _) => SortedSet.empty + case App(fn, args, _) => + args.foldLeft(allNames(fn))((bs, e) => bs | allNames(e)) case Lambda(args, e, _) => allNames(e) ++ args.toList.iterator.map(_._1) case Let(arg, expr, in, _, _) => allNames(expr) | allNames(in) + arg - case Literal(_, _) => SortedSet.empty + case Literal(_, _) => SortedSet.empty case Match(exp, branches, _) => - allNames(exp) | branches.foldMap { case (pat, res) => allNames(res) ++ pat.names } + allNames(exp) | branches.foldMap { case (pat, res) => + allNames(res) ++ pat.names + } } implicit def hasRegion[T: HasRegion]: HasRegion[Expr[T]] = @@ -80,63 +101,83 @@ object Expr { private[this] val TruePat: Pattern[(PackageName, Constructor), Type] = Pattern.PositionalStruct((PackageName.PredefName, Constructor("True")), Nil) private[this] val FalsePat: Pattern[(PackageName, Constructor), Type] = - Pattern.PositionalStruct((PackageName.PredefName, Constructor("False")), Nil) - /** - * build a Match expression that is equivalent to if/else using Predef::True and Predef::False - */ - def ifExpr[T](cond: Expr[T], ifTrue: Expr[T], ifFalse: Expr[T], tag: T): Expr[T] = + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("False")), + Nil + ) + + /** build a Match expression that is equivalent to if/else using Predef::True + * and Predef::False + */ + def ifExpr[T]( + cond: Expr[T], + ifTrue: Expr[T], + ifFalse: Expr[T], + tag: T + ): Expr[T] = Match(cond, NonEmptyList.of((TruePat, ifTrue), (FalsePat, ifFalse)), tag) - /** - * Build an apply expression by appling these args left to right - */ + /** Build an apply expression by appling these args left to right + */ def buildApp[A](fn: Expr[A], args: List[Expr[A]], appTag: A): Expr[A] = args match { case head :: tail => App(fn, NonEmptyList(head, tail), appTag) - case Nil => fn + case Nil => fn } // Traverse all non-bound vars - private def traverseType[T, F[_]](expr: Expr[T], bound: Set[Type.Var.Bound])(fn: (Type, Set[Type.Var.Bound]) => F[Type])(implicit F: Applicative[F]): F[Expr[T]] = + private def traverseType[T, F[_]](expr: Expr[T], bound: Set[Type.Var.Bound])( + fn: (Type, Set[Type.Var.Bound]) => F[Type] + )(implicit F: Applicative[F]): F[Expr[T]] = expr match { case Annotation(e, tpe, a) => (traverseType(e, bound)(fn), fn(tpe, bound)).mapN(Annotation(_, _, a)) case v: Name[T] => F.pure(v) case App(f, args, t) => - (traverseType(f, bound)(fn), args.traverse(traverseType(_, bound)(fn))).mapN(App(_, _, t)) + (traverseType(f, bound)(fn), args.traverse(traverseType(_, bound)(fn))) + .mapN(App(_, _, t)) case Generic(bs, in) => // Seems dangerous since we are hiding from fn that the Type.TyVar inside // matching these are not unbound val bound1 = bound ++ bs.toList.iterator.map(_._1) traverseType(in, bound1)(fn).map(Generic(bs, _)) case Lambda(args, expr, t) => - (args.traverse { case (n, optT) => optT.traverse(fn(_, bound)).map((n, _)) }, - traverseType(expr, bound)(fn)).mapN(Lambda(_, _, t)) + ( + args.traverse { case (n, optT) => + optT.traverse(fn(_, bound)).map((n, _)) + }, + traverseType(expr, bound)(fn) + ).mapN(Lambda(_, _, t)) case Let(arg, exp, in, rec, tag) => - (traverseType(exp, bound)(fn), traverseType(in, bound)(fn)).mapN(Let(arg, _, _, rec, tag)) - case l@Literal(_, _) => F.pure(l) + (traverseType(exp, bound)(fn), traverseType(in, bound)(fn)) + .mapN(Let(arg, _, _, rec, tag)) + case l @ Literal(_, _) => F.pure(l) case Match(arg, branches, tag) => val argB = traverseType(arg, bound)(fn) type B = (Pattern[(PackageName, Constructor), Type], Expr[T]) def branchFn(b: B): F[B] = b match { case (pat, expr) => - pat.traverseType(fn(_, bound)) + pat + .traverseType(fn(_, bound)) .product(traverseType(expr, bound)(fn)) } val branchB = branches.traverse(branchFn _) (argB, branchB).mapN(Match(_, _, tag)) } - private def substExpr[A](keys: NonEmptyList[Type.Var], vals: NonEmptyList[Type.Rho], expr: Expr[A]): Expr[A] = { + private def substExpr[A]( + keys: NonEmptyList[Type.Var], + vals: NonEmptyList[Type.Rho], + expr: Expr[A] + ): Expr[A] = { val fn = Type.substTy(keys, vals) traverseType[A, cats.Id](expr, Set.empty) { (t, bound) => - // we have to remove any of the keys that are bound - val isBound: Type.Var => Boolean = - { - case b @ Type.Var.Bound(_) => bound(b) - case _ => false - } + // we have to remove any of the keys that are bound + val isBound: Type.Var => Boolean = { + case b @ Type.Var.Bound(_) => bound(b) + case _ => false + } if (keys.exists(isBound)) { val kv1 = keys.zip(vals).toList.filter { case (b, _) => !isBound(b) } @@ -147,8 +188,7 @@ object Expr { case None => t } - } - else fn(t) + } else fn(t) } } @@ -162,24 +202,25 @@ object Expr { w.written.iterator.toList.distinct } - /** - * Here we substitute any free bound variables with skolem variables - * - * This is a deviation from the paper. - * We are allowing a syntax like: - * - * def identity(x: a) -> a: - * x - * - * or: - * - * def foo(x: a): x - * - * We handle this by converting a to a skolem variable, - * running inference, then quantifying over that skolem - * variable. - */ - def skolemizeVars[F[_]: Applicative, A](vs: NonEmptyList[(Type.Var.Bound, Kind)], expr: Expr[A])(newSkolemTyVar: (Type.Var.Bound, Kind) => F[Type.Var.Skolem]): F[(NonEmptyList[Type.Var.Skolem], Expr[A])] = { + /** Here we substitute any free bound variables with skolem variables + * + * This is a deviation from the paper. We are allowing a syntax like: + * + * def identity(x: a) -> a: x + * + * or: + * + * def foo(x: a): x + * + * We handle this by converting a to a skolem variable, running inference, + * then quantifying over that skolem variable. + */ + def skolemizeVars[F[_]: Applicative, A]( + vs: NonEmptyList[(Type.Var.Bound, Kind)], + expr: Expr[A] + )( + newSkolemTyVar: (Type.Var.Bound, Kind) => F[Type.Var.Skolem] + ): F[(NonEmptyList[Type.Var.Skolem], Expr[A])] = { vs.traverse { case (b, k) => newSkolemTyVar(b, k) } .map { skVs => val sksT = skVs.map(Type.TyVar(_)) @@ -203,7 +244,9 @@ object Expr { Traverse[NonEmptyList].compose(tup) } - def traverse[G[_]: Applicative, A, B](fa: Expr[A])(f: A => G[B]): G[Expr[B]] = + def traverse[G[_]: Applicative, A, B]( + fa: Expr[A] + )(f: A => G[B]): G[Expr[B]] = fa match { case Annotation(e, tpe, a) => (e.traverse(f), f(a)).mapN(Annotation(_, tpe, _)) @@ -214,8 +257,9 @@ object Expr { case Generic(bs, e) => traverse(e)(f).map(Generic(bs, _)) case App(fn, args, t) => - (fn.traverse(f), args.traverse(_.traverse(f)), f(t)).mapN { (fn1, a1, b) => - App(fn1, a1, b) + (fn.traverse(f), args.traverse(_.traverse(f)), f(t)).mapN { + (fn1, a1, b) => + App(fn1, a1, b) } case Lambda(args, expr, t) => (expr.traverse(f), f(t)).mapN { (e1, t1) => @@ -261,7 +305,9 @@ object Expr { f(b2, tag) } - def foldRight[A, B](fa: Expr[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = + def foldRight[A, B](fa: Expr[A], lb: Eval[B])( + f: (A, Eval[B]) => Eval[B] + ): Eval[B] = fa match { case Annotation(e, _, tag) => val lb1 = foldRight(e, lb)(f) @@ -289,17 +335,16 @@ object Expr { } def buildPatternLambda[A]( - args: NonEmptyList[Pattern[(PackageName, Constructor), Type]], - body: Expr[A], - outer: A): Expr[A] = { + args: NonEmptyList[Pattern[(PackageName, Constructor), Type]], + body: Expr[A], + outer: A + ): Expr[A] = { /* * compute this once if needed, which is why it is lazy. * we don't want to traverse body if it is never needed */ - lazy val anons = Type - .allBinders - .iterator + lazy val anons = Type.allBinders.iterator .map(_.name) .map(Identifier.Name(_)) .filterNot(allNames(body) ++ args.patternNames) @@ -327,4 +372,3 @@ object Expr { Lambda(justArgs, lambdaResult, outer) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala index bbf03eaf0..85484d855 100644 --- a/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala +++ b/core/src/main/scala/org/bykn/bosatsu/FfiCall.scala @@ -10,7 +10,9 @@ object FfiCall { final case class Fn1(fn: Value => Value) extends FfiCall { import Value.FnValue - private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) => fn(a) } + private[this] val evalFn: FnValue = FnValue { case NonEmptyList(a, _) => + fn(a) + } def call(t: rankn.Type): Value = evalFn } @@ -43,7 +45,7 @@ object FfiCall { def one(t: rankn.Type): Option[Class[_]] = loop(t, false) match { case c :: Nil => Some(c) - case _ => None + case _ => None } def loop(t: rankn.Type, top: Boolean): List[Class[_]] = { @@ -52,13 +54,15 @@ object FfiCall { val ats = as.map { a => one(a) match { case Some(at) => at - case function => sys.error(s"unsupported function type $function in $t") + case function => + sys.error(s"unsupported function type $function in $t") } } val res = one(b) match { case Some(at) => at - case function => sys.error(s"unsupported function type $function in $t") + case function => + sys.error(s"unsupported function type $function in $t") } ats.toList ::: res :: Nil case rankn.Type.ForAll(_, t) => @@ -69,4 +73,3 @@ object FfiCall { loop(t, true) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Fix.scala b/core/src/main/scala/org/bykn/bosatsu/Fix.scala index f53b76465..3f00bccf8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Fix.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Fix.scala @@ -1,11 +1,10 @@ package org.bykn.bosatsu final object FixType { - /** - * Use a trick in scala to give an opaque - * type for a fixed point recursion without - * having to allocate wrappers at each level - */ + + /** Use a trick in scala to give an opaque type for a fixed point recursion + * without having to allocate wrappers at each level + */ type Fix[F[_]] final def fix[F[_]](f: F[Fix[F]]): Fix[F] = diff --git a/core/src/main/scala/org/bykn/bosatsu/Identifier.scala b/core/src/main/scala/org/bykn/bosatsu/Identifier.scala index 1df6706a9..2b2bb1a6e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Identifier.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Identifier.scala @@ -2,7 +2,7 @@ package org.bykn.bosatsu import cats.Order import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import Parser.{lowerIdent, upperIdent} @@ -26,21 +26,21 @@ sealed abstract class Identifier { def toBindable: Option[Identifier.Bindable] = this match { case b: Identifier.Bindable => Some(b) - case _ => None + case _ => None } def toConstructor: Option[Identifier.Constructor] = this match { case c: Identifier.Constructor => Some(c) - case _ => None + case _ => None } } object Identifier { - /** - * These are names that can appear in bindings. Importantly, - * we can't bind constructor names except to define types - */ + + /** These are names that can appear in bindings. Importantly, we can't bind + * constructor names except to define types + */ sealed abstract class Bindable extends Identifier final case class Constructor(asString: String) extends Identifier @@ -60,8 +60,8 @@ object Identifier { case Backticked(lit) => Doc.char('`') + Doc.text(Parser.escape('`', lit)) + Doc.char('`') case Constructor(n) => Doc.text(n) - case Name(n) => Doc.text(n) - case Operator(n) => opPrefix + Doc.text(n) + case Name(n) => Doc.text(n) + case Operator(n) => opPrefix + Doc.text(n) } val nameParser: P[Name] = @@ -70,25 +70,24 @@ object Identifier { val consParser: P[Constructor] = upperIdent.map { c => Constructor(c.intern) } - /** - * This is used to apply operators, it is the - * raw operator tokens without an `operator` prefix - */ + /** This is used to apply operators, it is the raw operator tokens without an + * `operator` prefix + */ val rawOperator: P[Operator] = Operators.operatorToken.map { op => Operator(op.intern) } - /** - * the keyword operator preceding a rawOperator - */ + /** the keyword operator preceding a rawOperator + */ val operator: P[Operator] = (P.string("operator").soft *> Parser.spaces) *> rawOperator - /** - * Name, Backticked or non-raw operator - */ + /** Name, Backticked or non-raw operator + */ val bindableParser: P[Bindable] = // operator has to come first to not look like a Name - P.oneOf(operator :: nameParser :: Parser.escapedString('`').map { b => Backticked(b.intern) } :: Nil) + P.oneOf(operator :: nameParser :: Parser.escapedString('`').map { b => + Backticked(b.intern) + } :: Nil) val parser: P[Identifier] = bindableParser.orElse(consParser) @@ -98,22 +97,20 @@ object Identifier { def appendToName(i: Bindable, suffix: String): Bindable = i match { case Backticked(b) => Backticked(b + suffix) - case _ => + case _ => // try to stry the same val p = operator.orElse(nameParser) val cand = i.sourceCodeRepr + suffix p.parseAll(cand) match { case Right(ident) => ident - case _ => + case _ => // just turn it into a Backticked Backticked(i.asString + suffix) } } - - /** - * Build an Identifier by parsing a string - */ + /** Build an Identifier by parsing a string + */ def unsafe(str: String): Identifier = unsafeParse(parser, str) diff --git a/core/src/main/scala/org/bykn/bosatsu/Import.scala b/core/src/main/scala/org/bykn/bosatsu/Import.scala index b3bf1be3a..d3728bee8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Import.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Import.scala @@ -22,7 +22,9 @@ sealed abstract class ImportedName[+T] { ImportedName.Renamed(o, l, fn(t)) } - def traverse[F[_], U](fn: T => F[U])(implicit F: Functor[F]): F[ImportedName[U]] = + def traverse[F[_], U]( + fn: T => F[U] + )(implicit F: Functor[F]): F[ImportedName[U]] = this match { case ImportedName.OriginalName(n, t) => F.map(fn(t))(ImportedName.OriginalName(n, _)) @@ -32,10 +34,12 @@ sealed abstract class ImportedName[+T] { } object ImportedName { - case class OriginalName[T](originalName: Identifier, tag: T) extends ImportedName[T] { + case class OriginalName[T](originalName: Identifier, tag: T) + extends ImportedName[T] { def localName = originalName } - case class Renamed[T](originalName: Identifier, localName: Identifier, tag: T) extends ImportedName[T] + case class Renamed[T](originalName: Identifier, localName: Identifier, tag: T) + extends ImportedName[T] implicit val document: Document[ImportedName[Unit]] = Document.instance[ImportedName[Unit]] { @@ -50,7 +54,7 @@ object ImportedName { (of ~ (spaces.soft *> P.string("as") *> spaces *> of).?) .map { case (from, Some(to)) => ImportedName.Renamed(from, to, ()) - case (orig, None) => ImportedName.OriginalName(orig, ()) + case (orig, None) => ImportedName.OriginalName(orig, ()) } basedOn(Identifier.bindableParser) @@ -60,8 +64,9 @@ object ImportedName { case class Import[A, B](pack: A, items: NonEmptyList[ImportedName[B]]) { def resolveToGlobal: Map[Identifier, (A, Identifier)] = - items.foldLeft(Map.empty[Identifier, (A, Identifier)]) { case (m0, impName) => - m0.updated(impName.localName, (pack, impName.originalName)) + items.foldLeft(Map.empty[Identifier, (A, Identifier)]) { + case (m0, impName) => + m0.updated(impName.localName, (pack, impName.originalName)) } } @@ -70,7 +75,9 @@ object Import { Document.instance[Import[PackageName, Unit]] { case Import(pname, items) => val itemDocs = items.toList.map(Document[ImportedName[Unit]].document _) - Doc.text("from") + Doc.space + Document[PackageName].document(pname) + Doc.space + Doc.text("import") + + Doc.text("from") + Doc.space + Document[PackageName].document( + pname + ) + Doc.space + Doc.text("import") + // TODO: use paiges to pack this in nicely using .group or something Doc.space + Doc.intercalate(Doc.text(", "), itemDocs) } @@ -78,21 +85,24 @@ object Import { val parser: P[Import[PackageName, Unit]] = { val pyimps = ImportedName.parser.itemsMaybeParens.map(_._2) - ((P.string("from") ~ spaces).backtrack *> PackageName.parser <* spaces, - P.string("import") *> spaces *> pyimps) + ( + (P.string("from") ~ spaces).backtrack *> PackageName.parser <* spaces, + P.string("import") *> spaces *> pyimps + ) .mapN(Import(_, _)) } - /** - * This only keeps the last name if there are duplicate local names - * checking for duplicate local names should be done at another layer - */ - def locals[F[_]: Foldable, A, B, C](imp: Import[A, F[B]])(pn: PartialFunction[B, C]): Map[Identifier, C] = { + /** This only keeps the last name if there are duplicate local names checking + * for duplicate local names should be done at another layer + */ + def locals[F[_]: Foldable, A, B, C]( + imp: Import[A, F[B]] + )(pn: PartialFunction[B, C]): Map[Identifier, C] = { val fn = pn.lift imp.items.foldLeft(Map.empty[Identifier, C]) { case (m0, impName) => impName.tag.foldLeft(m0) { (m1, b) => fn(b) match { - case None => m1 + case None => m1 case Some(c) => m1.updated(impName.localName, c) } } @@ -100,9 +110,8 @@ object Import { } } -/** - * There are all the distinct imported names and the original ImportedName - */ +/** There are all the distinct imported names and the original ImportedName + */ case class ImportMap[A, B](toMap: Map[Identifier, (A, ImportedName[B])]) { def apply(name: Identifier): Option[(A, ImportedName[B])] = toMap.get(name) @@ -115,16 +124,18 @@ object ImportMap { def empty[A, B]: ImportMap[A, B] = ImportMap(Map.empty) // Return the list of collisions in local names along with a map // with the last name overwriting the import - def fromImports[A, B](is: List[Import[A, B]]): (List[(A, ImportedName[B])], ImportMap[A, B]) = + def fromImports[A, B]( + is: List[Import[A, B]] + ): (List[(A, ImportedName[B])], ImportMap[A, B]) = is.iterator .flatMap { case Import(p, is) => is.toList.iterator.map((p, _)) } .foldLeft((List.empty[(A, ImportedName[B])], ImportMap.empty[A, B])) { - case ((dups, imap), pim@(_, im)) => + case ((dups, imap), pim @ (_, im)) => val dups1 = imap(im.localName) match { case Some(nm) => nm :: dups - case None => dups + case None => dups } (dups1, imap + pim) - } + } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Indented.scala b/core/src/main/scala/org/bykn/bosatsu/Indented.scala index 4a32c4ebb..107fbf363 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Indented.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Indented.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import cats.parse.{Parser => P} @@ -13,9 +13,9 @@ case class Indented[T](spaces: Int, value: T) { object Indented { def spaceCount(str: String): Int = str.foldLeft(0) { - case (s, ' ') => s + 1 + case (s, ' ') => s + 1 case (s, '\t') => s + 4 - case (_, c) => sys.error(s"unexpected space character($c) in $str") + case (_, c) => sys.error(s"unexpected space character($c) in $str") } implicit def document[T: Document]: Document[Indented[T]] = @@ -23,14 +23,11 @@ object Indented { Doc.spaces(i) + (Document[T].document(t).nested(i)) } - - /** - * This reads a new line at a deeper indentation level - * than we currently are. - * - * So we are starting from the 0 column and read - * the current indentation level plus at least one space more - */ + /** This reads a new line at a deeper indentation level than we currently are. + * + * So we are starting from the 0 column and read the current indentation + * level plus at least one space more + */ def indy[T](p: Parser.Indy[T]): Parser.Indy[Indented[T]] = Parser.Indy { indent => for { @@ -39,4 +36,3 @@ object Indented { } yield Indented(Indented.spaceCount(thisIndent), t) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala b/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala index af335f390..a8d4c0e7d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala +++ b/core/src/main/scala/org/bykn/bosatsu/IorMethods.scala @@ -6,8 +6,8 @@ object IorMethods { implicit class IorExtension[A, B](val ior: Ior[A, B]) extends AnyVal { def strictToValidated: Validated[A, B] = ior match { - case Ior.Right(b) => Validated.valid(b) - case Ior.Left(a) => Validated.invalid(a) + case Ior.Right(b) => Validated.valid(b) + case Ior.Left(a) => Validated.invalid(a) case Ior.Both(a, _) => Validated.invalid(a) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Json.scala b/core/src/main/scala/org/bykn/bosatsu/Json.scala index 159a56a22..94d549277 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Json.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Json.scala @@ -5,9 +5,8 @@ import org.typelevel.paiges.Doc import cats.parse.{Parser0 => P0, Parser => P} import cats.Eq -/** - * A simple JSON ast for output - */ +/** A simple JSON ast for output + */ sealed abstract class Json { def toDoc: Doc @@ -45,7 +44,7 @@ object Json { } def unapply(j: Json): Option[BigInteger] = j match { - case num@JNumberStr(str) => + case num @ JNumberStr(str) => if (allDigits(str)) Some(new BigInteger(str)) else num.toBigInteger case _ => None @@ -70,9 +69,9 @@ object Json { def unapply(j: Json): Option[Boolean] = j match { - case True => someTrue + case True => someTrue case False => someFalse - case _ => None + case _ => None } } @@ -83,7 +82,10 @@ object Json { } final case class JArray(toVector: Vector[Json]) extends Json { def toDoc = { - val parts = Doc.intercalate(Doc.comma, toVector.map { j => (Doc.line + j.toDoc).grouped }) + val parts = Doc.intercalate( + Doc.comma, + toVector.map { j => (Doc.line + j.toDoc).grouped } + ) "[" +: ((parts :+ " ]").nested(2)) } @@ -104,56 +106,54 @@ object Json { parts.bracketBy(text("{"), text("}")) } - /** - * Return a JObject with each key at most once, but in the order of this - */ + /** Return a JObject with each key at most once, but in the order of this + */ def normalize: JObject = JObject(keys.map { k => (k, toMap(k)) }) def render = toDoc.render(80) } - /** - * this checks for semantic equivalence: - * 1. we use BigDecimal to compare JNumberStr - * 2. we normalize objects - */ + /** this checks for semantic equivalence: + * 1. we use BigDecimal to compare JNumberStr 2. we normalize objects + */ implicit val eqJson: Eq[Json] = new Eq[Json] { def eqv(a: Json, b: Json) = (a, b) match { - case (JNull, JNull) => true - case (JBool.True, JBool.True) => true + case (JNull, JNull) => true + case (JBool.True, JBool.True) => true case (JBool.False, JBool.False) => true case (JString(sa), JString(sb)) => sa == sb case (JNumberStr(sa), JNumberStr(sb)) => new BigDecimal(sa).compareTo(new BigDecimal(sb)) == 0 case (JArray(itemsa), JArray(itemsb)) => (itemsa.size == itemsb.size) && - itemsa.iterator - .zip(itemsb.iterator) - .forall { case (a, b) => eqv(a, b) } - case (oa@JObject(_), ob@JObject(_)) => + itemsa.iterator + .zip(itemsb.iterator) + .forall { case (a, b) => eqv(a, b) } + case (oa @ JObject(_), ob @ JObject(_)) => val na = oa.normalize val nb = ob.normalize (na.toMap.keySet == nb.toMap.keySet) && - na.keys.forall { k => - eqv(na.toMap(k), nb.toMap(k)) - } + na.keys.forall { k => + eqv(na.toMap(k), nb.toMap(k)) + } case (_, _) => false } } private[this] val whitespace: P[Unit] = P.charIn(" \t\r\n").void private[this] val whitespaces0: P0[Unit] = whitespace.rep0.void - /** - * This doesn't have to be super fast (but is fairly fast) since we use it in places - * where speed won't matter: feeding it into a program that will convert it to bosatsu - * structured data - */ + + /** This doesn't have to be super fast (but is fairly fast) since we use it in + * places where speed won't matter: feeding it into a program that will + * convert it to bosatsu structured data + */ val parser: P[Json] = { val recurse = P.defer(parser) val pnull = P.string("null").as(JNull) - val bool = P.string("true").as(JBool.True).orElse(P.string("false").as(JBool.False)) + val bool = + P.string("true").as(JBool.True).orElse(P.string("false").as(JBool.False)) val justStr = JsonStringUtil.escapedString('"') val str = justStr.map(JString(_)) val num = Parser.JsonNumber.parser.map(JNumberStr(_)) @@ -177,6 +177,7 @@ object Json { } // any whitespace followed by json followed by whitespace followed by end - val parserFile: P[Json] = whitespaces0.with1 *> (parser ~ whitespaces0 ~ P.end).map(_._1._1) + val parserFile: P[Json] = + whitespaces0.with1 *> (parser ~ whitespaces0 ~ P.end).map(_._1._1) } diff --git a/core/src/main/scala/org/bykn/bosatsu/Kind.scala b/core/src/main/scala/org/bykn/bosatsu/Kind.scala index 9d8a37ff1..4ed3c0e5d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Kind.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Kind.scala @@ -101,9 +101,11 @@ object Kind { case _ => false } - def validApply[A](left: Kind, right: Kind, onTypeErr: => A)(onSubsumeFail: Cons => A): Either[A, Kind] = + def validApply[A](left: Kind, right: Kind, onTypeErr: => A)( + onSubsumeFail: Cons => A + ): Either[A, Kind] = left match { - case cons@Cons(Kind.Arg(_, lhs), res) => + case cons @ Cons(Kind.Arg(_, lhs), res) => if (leftSubsumesRight(lhs, right)) Right(res) else Left(onSubsumeFail(cons)) case Kind.Type => Left(onTypeErr) diff --git a/core/src/main/scala/org/bykn/bosatsu/ListLang.scala b/core/src/main/scala/org/bykn/bosatsu/ListLang.scala index eb486a054..c19c1254b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ListLang.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ListLang.scala @@ -7,10 +7,9 @@ import cats.parse.{Parser => P} import cats.implicits._ -/** - * Represents the list construction sublanguage - * A is the expression type, B is the pattern type for bindings - */ +/** Represents the list construction sublanguage A is the expression type, B is + * the pattern type for bindings + */ sealed abstract class ListLang[F[_], A, +B] object ListLang { sealed abstract class SpliceOrItem[A] { @@ -37,10 +36,12 @@ object ListLang { .map(Splice(_)) .orElse(pa.map(Item(_))) - implicit def document[A](implicit A: Document[A]): Document[SpliceOrItem[A]] = + implicit def document[A](implicit + A: Document[A] + ): Document[SpliceOrItem[A]] = Document.instance[SpliceOrItem[A]] { case Splice(a) => Doc.char('*') + A.document(a) - case Item(a) => A.document(a) + case Item(a) => A.document(a) } } @@ -58,38 +59,58 @@ object ListLang { .map { case (k, v) => KVPair(k, v) } implicit def document[A](implicit A: Document[A]): Document[KVPair[A]] = - Document.instance[KVPair[A]] { - case KVPair(k, v) => A.document(k) + sep + A.document(v) + Document.instance[KVPair[A]] { case KVPair(k, v) => + A.document(k) + sep + A.document(v) } } case class Cons[F[_], A](items: List[F[A]]) extends ListLang[F, A, Nothing] - case class Comprehension[F[_], A, B](expr: F[A], binding: B, in: A, filter: Option[A]) extends ListLang[F, A, B] - - def parser[A, B](pa: P[A], psrc: P[A], pbind: P[B]): P[ListLang[SpliceOrItem, A, B]] = + case class Comprehension[F[_], A, B]( + expr: F[A], + binding: B, + in: A, + filter: Option[A] + ) extends ListLang[F, A, B] + + def parser[A, B]( + pa: P[A], + psrc: P[A], + pbind: P[B] + ): P[ListLang[SpliceOrItem, A, B]] = genParser(P.char('['), SpliceOrItem.parser(pa), psrc, pbind, P.char(']')) - def dictParser[A, B](pa: P[A], psrc: P[A], pbind: P[B]): P[ListLang[KVPair, A, B]] = + def dictParser[A, B]( + pa: P[A], + psrc: P[A], + pbind: P[B] + ): P[ListLang[KVPair, A, B]] = genParser(P.char('{'), KVPair.parser(pa), psrc, pbind, P.char('}')) - def genParser[F[_], A, B](left: P[Unit], fa: P[F[A]], pa: P[A], pbind: P[B], right: P[Unit]): P[ListLang[F, A, B]] = { + def genParser[F[_], A, B]( + left: P[Unit], + fa: P[F[A]], + pa: P[A], + pbind: P[B], + right: P[Unit] + ): P[ListLang[F, A, B]] = { // construct the tail of a list, so we will finally have at least one item - val consTail = fa.nonEmptyListOfWs(maybeSpacesAndLines).? - .map { tail => - val listTail = tail match { - case None => Nil - case Some(ne) => ne.toList - } - - { (a: F[A]) => Cons(a :: listTail) } + val consTail = fa.nonEmptyListOfWs(maybeSpacesAndLines).?.map { tail => + val listTail = tail match { + case None => Nil + case Some(ne) => ne.toList } + { (a: F[A]) => Cons(a :: listTail) } + } + val filterExpr = P.string("if") *> spacesAndLines *> pa val comp = - (P.string("for") *> spacesAndLines *> pbind <* maybeSpacesAndLines, - P.string("in") *> spacesAndLines *> pa <* maybeSpacesAndLines, - filterExpr.?) + ( + P.string("for") *> spacesAndLines *> pbind <* maybeSpacesAndLines, + P.string("in") *> spacesAndLines *> pa <* maybeSpacesAndLines, + filterExpr.? + ) .mapN { (b, i, f) => { (e: F[A]) => Comprehension(e, b, i, f) } } @@ -99,21 +120,24 @@ object ListLang { (left *> maybeSpacesAndLines *> (fa ~ inner.?).? <* maybeSpacesAndLines <* right) .map { - case None => Cons(Nil) - case Some((a, None)) => Cons(a :: Nil) + case None => Cons(Nil) + case Some((a, None)) => Cons(a :: Nil) case Some((a, Some(rest))) => rest(a) } } - def genDocument[F[_], A, B](left: Doc, right: Doc)(implicit F: Document[F[A]], A: Document[A], B: Document[B]): Document[ListLang[F, A, B]] = + def genDocument[F[_], A, B](left: Doc, right: Doc)(implicit + F: Document[F[A]], + A: Document[A], + B: Document[B] + ): Document[ListLang[F, A, B]] = Document.instance[ListLang[F, A, B]] { case Cons(items) => - left + Doc.intercalate(Doc.text(", "), - items.map(F.document(_))) + + left + Doc.intercalate(Doc.text(", "), items.map(F.document(_))) + right case Comprehension(e, b, i, f) => val filt = f match { - case None => Doc.empty + case None => Doc.empty case Some(e) => Doc.text(" if ") + A.document(e) } left + F.document(e) + Doc.text(" for ") + @@ -122,10 +146,15 @@ object ListLang { right } - implicit def document[A, B](implicit A: Document[A], B: Document[B]): Document[ListLang[SpliceOrItem, A, B]] = + implicit def document[A, B](implicit + A: Document[A], + B: Document[B] + ): Document[ListLang[SpliceOrItem, A, B]] = genDocument[SpliceOrItem, A, B](Doc.char('['), Doc.char(']')) - implicit def documentDict[A, B](implicit A: Document[A], B: Document[B]): Document[ListLang[KVPair, A, B]] = + implicit def documentDict[A, B](implicit + A: Document[A], + B: Document[B] + ): Document[ListLang[KVPair, A, B]] = genDocument[KVPair, A, B](Doc.char('{'), Doc.char('}')) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Lit.scala b/core/src/main/scala/org/bykn/bosatsu/Lit.scala index 2213f515e..527028956 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Lit.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Lit.scala @@ -10,7 +10,7 @@ sealed abstract class Lit { def repr: String = this match { case Lit.Integer(i) => i.toString - case Lit.Str(s) => "\"" + escape('"', s) + "\"" + case Lit.Str(s) => "\"" + escape('"', s) + "\"" } def unboxToAny: Any @@ -31,7 +31,9 @@ object Lit { def apply(str: String): Lit = Str(str) val integerParser: P[Integer] = - Parser.integerString.map { str => Integer(new BigInteger(str.filterNot(_ == '_'))) } + Parser.integerString.map { str => + Integer(new BigInteger(str.filterNot(_ == '_'))) + } val stringParser: P[Str] = { val q1 = '\'' @@ -47,9 +49,9 @@ object Lit { def compare(a: Lit, b: Lit): Int = (a, b) match { case (Integer(a), Integer(b)) => a.compareTo(b) - case (Integer(_), Str(_)) => -1 - case (Str(_), Integer(_)) => 1 - case (Str(a), Str(b)) => a.compareTo(b) + case (Integer(_), Str(_)) => -1 + case (Str(_), Integer(_)) => 1 + case (Str(a), Str(b)) => a.compareTo(b) } } @@ -64,4 +66,3 @@ object Lit { Doc.char(q) + Doc.text(escape(q, str)) + Doc.char(q) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala b/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala index baffe6e09..9a5caa101 100644 --- a/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/LocationMap.scala @@ -7,15 +7,13 @@ import cats.implicits._ import LocationMap.Colorize -/** - * Build a cache of the rows and columns in a given - * string. This is for showing error messages to users - */ +/** Build a cache of the rows and columns in a given string. This is for showing + * error messages to users + */ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { private def lineRange(start: Int, end: Int): List[(Int, String)] = - (start to end) - .iterator + (start to end).iterator .filter(_ >= 0) .map { r => val liner = getLine(r).get // should never throw @@ -24,10 +22,9 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { } .toList - /** - * convert tab to tab, but otherwise space - * return the white space before this column - */ + /** convert tab to tab, but otherwise space return the white space before this + * column + */ private def spaceOf(row: Int, col: Int): Option[String] = getLine(row) .map { line => @@ -42,7 +39,11 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { bldr.toString() } - def showContext(offset: Int, previousLines: Int, color: Colorize): Option[Doc] = + def showContext( + offset: Int, + previousLines: Int, + color: Colorize + ): Option[Doc] = toLineCol(offset) .map { case (r, c) => val lines = lineRange(r - previousLines, r) @@ -60,10 +61,17 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { val ctx = Doc.intercalate(Doc.hardLine, lineDocs) // convert to spaces val colPad = spaceOf(r, c).get - ctx + Doc.hardLine + pointerPad + LocationMap.pointerTo(colPad, color) + Doc.hardLine + ctx + Doc.hardLine + pointerPad + LocationMap.pointerTo( + colPad, + color + ) + Doc.hardLine } - def showRegion(region: Region, previousLines: Int, color: Colorize): Option[Doc] = + def showRegion( + region: Region, + previousLines: Int, + color: Colorize + ): Option[Doc] = (toLineCol(region.start), toLineCol(region.end - 1)) .mapN { case ((l0, c0), (l1, c1)) => val lines = lineRange(l0 - previousLines, l1) @@ -78,14 +86,19 @@ case class LocationMap(fromString: String) extends CPLocationMap(fromString) { // same line // here is how much extra we need for the pointer val pointerPad = Doc.spaces(toLineStr(l0).render(0).length) - val lineDocs = lines.map { case (no, l) => toLineStr(no) + Doc.text(l) } + val lineDocs = lines.map { case (no, l) => + toLineStr(no) + Doc.text(l) + } val ctx = Doc.intercalate(Doc.hardLine, lineDocs) val c0Pad = spaceOf(l0, c0).get // we go one more to cover the column val c1Pad = spaceOf(l0, c1 + 1).get - ctx + Doc.hardLine + pointerPad + LocationMap.pointerRange(c0Pad, c1Pad, color) + Doc.hardLine - } - else { + ctx + Doc.hardLine + pointerPad + LocationMap.pointerRange( + c0Pad, + c1Pad, + color + ) + Doc.hardLine + } else { // we span multiple lines, show the start and the end: val newPrev = l1 - l0 showContext(region.start, previousLines, color).get + @@ -110,26 +123,32 @@ object LocationMap { object Console extends Colorize { def red(d: Doc) = - Doc.zeroWidth(scala.Console.RED) + d.unzero + Doc.zeroWidth(scala.Console.RESET) + Doc.zeroWidth(scala.Console.RED) + d.unzero + Doc.zeroWidth( + scala.Console.RESET + ) def green(d: Doc) = - Doc.zeroWidth(scala.Console.GREEN) + d.unzero + Doc.zeroWidth(scala.Console.RESET) + Doc.zeroWidth(scala.Console.GREEN) + d.unzero + Doc.zeroWidth( + scala.Console.RESET + ) } object HmtlFont extends Colorize { def red(d: Doc) = - Doc.zeroWidth("") + d.unzero + Doc.zeroWidth("") + Doc.zeroWidth("") + d.unzero + Doc.zeroWidth( + "" + ) def green(d: Doc) = - Doc.zeroWidth("") + d.unzero + Doc.zeroWidth("") + Doc.zeroWidth("") + d.unzero + Doc.zeroWidth( + "" + ) } } - /** - * Provide a string that points with a carat to a given column - * with 0 based indexing: - * e.g. pointerTo(2) == " ^" - */ + /** Provide a string that points with a carat to a given column with 0 based + * indexing: e.g. pointerTo(2) == " ^" + */ def pointerTo(colStr: String, color: Colorize): Doc = { val col = Doc.text(colStr) val pointer = Doc.char('^') @@ -141,7 +160,7 @@ object LocationMap { // just use tab for any tabs val pointerStr = endPad.drop(startPad.length).map { case '\t' => '\t' - case _ => '^' + case _ => '^' } val pointer = Doc.text(pointerStr) col + color.red(pointer) diff --git a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala index 28c65444c..3201de9ce 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Matchless.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Matchless.scala @@ -26,13 +26,22 @@ object Matchless { // we should probably allocate static slots for each bindable, // and replace the local with an integer offset slot access for // the closure state - case class Lambda(captures: List[Bindable], args: NonEmptyList[Bindable], expr: Expr) extends FnExpr + case class Lambda( + captures: List[Bindable], + args: NonEmptyList[Bindable], + expr: Expr + ) extends FnExpr // this is a tail recursive function that should be compiled into a loop // when a call to name is done inside body, that should restart the loop // the type of this Expr a function with the arity of args that returns // the type of body - case class LoopFn(captures: List[Bindable], name: Bindable, arg: NonEmptyList[Bindable], body: Expr) extends FnExpr + case class LoopFn( + captures: List[Bindable], + name: Bindable, + arg: NonEmptyList[Bindable], + body: Expr + ) extends FnExpr case class Global(pack: PackageName, name: Bindable) extends CheapExpr @@ -46,7 +55,11 @@ object Matchless { // we aggregate all the applications to potentially make dispatch more efficient // note fn is never an App case class App(fn: Expr, arg: NonEmptyList[Expr]) extends Expr - case class Let(arg: Either[LocalAnon, (Bindable, RecursionKind)], expr: Expr, in: Expr) extends Expr + case class Let( + arg: Either[LocalAnon, (Bindable, RecursionKind)], + expr: Expr, + in: Expr + ) extends Expr case class LetMut(name: LocalAnonMut, span: Expr) extends Expr case class Literal(lit: Lit) extends CheapExpr @@ -58,7 +71,7 @@ object Matchless { (this, that) match { case (TrueConst, r) => r case (l, TrueConst) => l - case _ => And(this, that) + case _ => And(this, that) } } // returns 1 if it does, else 0 @@ -68,16 +81,30 @@ object Matchless { case class And(e1: BoolExpr, e2: BoolExpr) extends BoolExpr // checks if variant matches, and if so, writes to // a given mut - case class CheckVariant(expr: CheapExpr, expect: Int, size: Int, famArities: List[Int]) extends BoolExpr + case class CheckVariant( + expr: CheapExpr, + expect: Int, + size: Int, + famArities: List[Int] + ) extends BoolExpr // handle list matching, this is a while loop, that is evaluting // lst is initialized to init, leftAcc is initialized to empty // tail until it is true while mutating lst => lst.tail // this has the side-effect of mutating lst and leftAcc as well as any side effects that check has // which could have nested searches of its own - case class SearchList(lst: LocalAnonMut, init: CheapExpr, check: BoolExpr, leftAcc: Option[LocalAnonMut]) extends BoolExpr + case class SearchList( + lst: LocalAnonMut, + init: CheapExpr, + check: BoolExpr, + leftAcc: Option[LocalAnonMut] + ) extends BoolExpr // set the mutable variable to the given expr and return true // string matching is complex done at a lower level - case class MatchString(arg: CheapExpr, parts: List[StrPart], binds: List[LocalAnonMut]) extends BoolExpr + case class MatchString( + arg: CheapExpr, + parts: List[StrPart], + binds: List[LocalAnonMut] + ) extends BoolExpr // set the mutable variable to the given expr and return true case class SetMut(target: LocalAnonMut, expr: Expr) extends BoolExpr case object TrueConst extends BoolExpr @@ -85,9 +112,11 @@ object Matchless { def hasSideEffect(bx: BoolExpr): Boolean = bx match { case SetMut(_, _) => true - case TrueConst | CheckVariant(_, _, _, _) | EqualsLit(_, _) | EqualsNat(_, _) => false + case TrueConst | CheckVariant(_, _, _, _) | EqualsLit(_, _) | + EqualsNat(_, _) => + false case MatchString(_, _, b) => b.nonEmpty - case And(b1, b2) => hasSideEffect(b1) || hasSideEffect(b2) + case And(b1, b2) => hasSideEffect(b1) || hasSideEffect(b2) case SearchList(_, _, b, l) => l.nonEmpty || hasSideEffect(b) } @@ -98,18 +127,20 @@ object Matchless { if (hasSideEffect(cond)) Always(cond, thenExpr) else thenExpr - /** - * These aren't really super cheap, but when we treat them cheap we check that we will only - * call them one time - */ - case class GetEnumElement(arg: CheapExpr, variant: Int, index: Int, size: Int) extends CheapExpr - case class GetStructElement(arg: CheapExpr, index: Int, size: Int) extends CheapExpr + /** These aren't really super cheap, but when we treat them cheap we check + * that we will only call them one time + */ + case class GetEnumElement(arg: CheapExpr, variant: Int, index: Int, size: Int) + extends CheapExpr + case class GetStructElement(arg: CheapExpr, index: Int, size: Int) + extends CheapExpr sealed abstract class ConsExpr extends Expr { def arity: Int } // we need to compile calls to constructors into these - case class MakeEnum(variant: Int, arity: Int, famArities: List[Int]) extends ConsExpr + case class MakeEnum(variant: Int, arity: Int, famArities: List[Int]) + extends ConsExpr case class MakeStruct(arity: Int) extends ConsExpr case object ZeroNat extends ConsExpr { def arity = 0 @@ -124,25 +155,27 @@ object Matchless { private def asCheap(expr: Expr): Option[CheapExpr] = expr match { case c: CheapExpr => Some(c) - case _ => None + case _ => None } - private def maybeMemo[F[_]: Monad](tmp: F[Long])(fn: CheapExpr => F[Expr]): Expr => F[Expr] = - { (arg: Expr) => - asCheap(arg) match { - case Some(c) => fn(c) - case None => - for { - nm <- tmp - bound = LocalAnon(nm) - res <- fn(bound) - } yield Let(Left(bound), arg, res) - } + private def maybeMemo[F[_]: Monad]( + tmp: F[Long] + )(fn: CheapExpr => F[Expr]): Expr => F[Expr] = { (arg: Expr) => + asCheap(arg) match { + case Some(c) => fn(c) + case None => + for { + nm <- tmp + bound = LocalAnon(nm) + res <- fn(bound) + } yield Let(Left(bound), arg, res) } + } private[this] val empty = (PackageName.PredefName, Constructor("EmptyList")) private[this] val cons = (PackageName.PredefName, Constructor("NonEmptyList")) - private[this] val reverseFn = Global(PackageName.PredefName, Identifier.Name("reverse")) + private[this] val reverseFn = + Global(PackageName.PredefName, Identifier.Name("reverse")) // drop all items in the tail after the first time fn returns true // as a result, we have 0 or 1 items where fn is true in the result @@ -150,53 +183,58 @@ object Matchless { def stopAt[A](nel: NonEmptyList[A])(fn: A => Boolean): NonEmptyList[A] = nel match { case NonEmptyList(h, _) if fn(h) => NonEmptyList(h, Nil) - case s@NonEmptyList(_, Nil) => s - case NonEmptyList(h0, h1 :: t) => h0 :: stopAt(NonEmptyList(h1, t))(fn) + case s @ NonEmptyList(_, Nil) => s + case NonEmptyList(h0, h1 :: t) => h0 :: stopAt(NonEmptyList(h1, t))(fn) } // same as fromLet below, but uses RefSpace - def fromLet[A]( - name: Bindable, - rec: RecursionKind, - te: TypedExpr[A])( - variantOf: (PackageName, Constructor) => Option[DataRepr]): Expr = - (for { - c <- RefSpace.allocCounter - expr <- fromLet(name, rec, te, variantOf, c) - } yield expr).run.value + def fromLet[A](name: Bindable, rec: RecursionKind, te: TypedExpr[A])( + variantOf: (PackageName, Constructor) => Option[DataRepr] + ): Expr = + (for { + c <- RefSpace.allocCounter + expr <- fromLet(name, rec, te, variantOf, c) + } yield expr).run.value // we need a TypeEnv to inline the creation of structs and variants def fromLet[F[_]: Monad, A]( - name: Bindable, - rec: RecursionKind, - te: TypedExpr[A], - variantOf: (PackageName, Constructor) => Option[DataRepr], - makeAnon: F[Long]): F[Expr] = { + name: Bindable, + rec: RecursionKind, + te: TypedExpr[A], + variantOf: (PackageName, Constructor) => Option[DataRepr], + makeAnon: F[Long] + ): F[Expr] = { - type UnionMatch = NonEmptyList[(List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)])] - val wildMatch: UnionMatch = NonEmptyList((Nil, TrueConst, Nil), Nil) + type UnionMatch = + NonEmptyList[(List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)])] + val wildMatch: UnionMatch = NonEmptyList((Nil, TrueConst, Nil), Nil) val emptyExpr: Expr = empty match { case (p, c) => variantOf(p, c) match { case Some(DataRepr.Enum(v, s, f)) => MakeEnum(v, s, f) - case other => + case other => /* We assume the structure of Lists to be standard linked lists * Empty cannot be a struct */ // $COVERAGE-OFF$ - throw new IllegalStateException(s"empty List should be an enum, found: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"empty List should be an enum, found: $other" + ) + // $COVERAGE-ON$ } } - def loopLetVal(name: Bindable, e: TypedExpr[A], rec: RecursionKind): F[Expr] = { + def loopLetVal( + name: Bindable, + e: TypedExpr[A], + rec: RecursionKind + ): F[Expr] = { lazy val e0 = loop(e) rec match { case RecursionKind.Recursive => - def letrec(e: Expr): Expr = Let(Right((name, RecursionKind.Recursive)), e, Local(name)) @@ -217,8 +255,7 @@ object Matchless { // but it definitely does in fuzz tests e0.map(letrec) } - } - else { + } else { // otherwise let rec x = fn in x e0.map(letrec) } @@ -228,25 +265,27 @@ object Matchless { def loop(te: TypedExpr[A]): F[Expr] = te match { - case TypedExpr.Generic(_, expr) => loop(expr) + case TypedExpr.Generic(_, expr) => loop(expr) case TypedExpr.Annotation(term, _) => loop(term) case TypedExpr.AnnotatedLambda(args, res, _) => val captures = TypedExpr.freeVars(te :: Nil) loop(res).map(Lambda(captures, args.map(_._1), _)) - case TypedExpr.Global(pack, cons@Constructor(_), _, _) => + case TypedExpr.Global(pack, cons @ Constructor(_), _, _) => Monad[F].pure(variantOf(pack, cons) match { case Some(dr) => dr match { case DataRepr.Enum(v, a, f) => MakeEnum(v, a, f) - case DataRepr.Struct(a) => MakeStruct(a) - case DataRepr.NewType => MakeStruct(1) - case DataRepr.ZeroNat => ZeroNat - case DataRepr.SuccNat => SuccNat + case DataRepr.Struct(a) => MakeStruct(a) + case DataRepr.NewType => MakeStruct(1) + case DataRepr.ZeroNat => ZeroNat + case DataRepr.SuccNat => SuccNat } - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case None => - throw new IllegalStateException(s"could not find $cons in global data types") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"could not find $cons in global data types" + ) + // $COVERAGE-ON$ }) case TypedExpr.Global(pack, notCons: Bindable, _, _) => Monad[F].pure(Global(pack, notCons)) @@ -258,8 +297,10 @@ object Matchless { (loopLetVal(a, e, r), loop(in)).mapN(Let(Right((a, r)), _, _)) case TypedExpr.Literal(lit, _, _) => Monad[F].pure(Literal(lit)) case TypedExpr.Match(arg, branches, _) => - (loop(arg), branches.traverse { case (p, te) => loop(te).map((p, _)) }) - .tupled + ( + loop(arg), + branches.traverse { case (p, te) => loop(te).map((p, _)) } + ).tupled .flatMap { case (a, b) => matchExpr(a, makeAnon, b) } } @@ -269,9 +310,11 @@ object Matchless { * 2. a total binding to a given name * 3. or we return None indicating not one of these */ - def maybeSimple(p: Pattern[(PackageName, Constructor), Type]): Option[Either[Bindable, Unit]] = + def maybeSimple( + p: Pattern[(PackageName, Constructor), Type] + ): Option[Either[Bindable, Unit]] = p match { - case Pattern.WildCard => Some(Right(())) + case Pattern.WildCard => Some(Right(())) case Pattern.Literal(_) => // Literals are never total None @@ -279,21 +322,21 @@ object Matchless { case Pattern.Named(v, p) => maybeSimple(p) match { case Some(Right(_)) => Some(Left(v)) - case _ => None + case _ => None } case Pattern.StrPat(s) => s match { case NonEmptyList(Pattern.StrPart.WildStr, Nil) => Some(Right(())) case NonEmptyList(Pattern.StrPart.NamedStr(n), Nil) => Some(Left(n)) - case _ => None + case _ => None } case Pattern.ListPat(l) => l match { - case Pattern.ListPart.WildList :: Nil => Some(Right(())) + case Pattern.ListPart.WildList :: Nil => Some(Right(())) case Pattern.ListPart.NamedList(n) :: Nil => Some(Left(n)) - case _ => None + case _ => None } - case Pattern.Annotation(p, _) => maybeSimple(p) + case Pattern.Annotation(p, _) => maybeSimple(p) case Pattern.PositionalStruct((pack, cname), ps) => // Only branch-free structs with no inner names are simple variantOf(pack, cname) match { @@ -306,10 +349,12 @@ object Matchless { } case _ => None } - // $COVERAGE-OFF$ + // $COVERAGE-OFF$ case None => - throw new IllegalStateException(s"could not find $cons in global data types") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"could not find $cons in global data types" + ) + // $COVERAGE-ON$ } case Pattern.Union(h, t) => (h :: t.toList).traverse(maybeSimple).flatMap { inners => @@ -320,7 +365,11 @@ object Matchless { // return the check expression for the check we need to do, and the list of bindings // if must match is true, we know that the pattern must match, so we can potentially remove some checks - def doesMatch(arg: CheapExpr, pat: Pattern[(PackageName, Constructor), Type], mustMatch: Boolean): F[UnionMatch] = { + def doesMatch( + arg: CheapExpr, + pat: Pattern[(PackageName, Constructor), Type], + mustMatch: Boolean + ): F[UnionMatch] = { pat match { case Pattern.WildCard => // this is a total pattern @@ -335,37 +384,34 @@ object Matchless { }) case Pattern.StrPat(items) => val sbinds: List[Bindable] = - items - .toList + items.toList .collect { // that each name is distinct // should be checked in the SourceConverter/TotalityChecking code case Pattern.StrPart.NamedStr(n) => n } - val muts = sbinds.traverse { b => makeAnon.map(LocalAnonMut(_)).map((b, _)) } + val muts = sbinds.traverse { b => + makeAnon.map(LocalAnonMut(_)).map((b, _)) + } val pat = items.toList.map { - case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr - case Pattern.StrPart.WildStr => StrPart.WildStr - case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) - } + case Pattern.StrPart.NamedStr(_) => StrPart.IndexStr + case Pattern.StrPart.WildStr => StrPart.WildStr + case Pattern.StrPart.LitStr(s) => StrPart.LitStr(s) + } muts.map { binds => val ms = binds.map(_._2) - NonEmptyList.of((ms, - MatchString( - arg, - pat, - ms), - binds)) + NonEmptyList.of((ms, MatchString(arg, pat, ms), binds)) } - case lp@Pattern.ListPat(_) => - + case lp @ Pattern.ListPat(_) => lp.toPositionalStruct(empty, cons) match { case Right(p) => doesMatch(arg, p, mustMatch) - case Left((glob, right@NonEmptyList(Pattern.ListPart.Item(_), _))) => + case Left( + (glob, right @ NonEmptyList(Pattern.ListPart.Item(_), _)) + ) => // we have a non-trailing list pattern // to match, this becomes a search problem // we loop over all the matches of p in the list, @@ -383,8 +429,7 @@ object Matchless { makeAnon.map { nm => Some((LocalAnonMut(nm), ln)) } } - (leftF, makeAnon) - .tupled + (leftF, makeAnon).tupled .flatMap { case (optAnonLeft, tmpList) => val anonList = LocalAnonMut(tmpList) @@ -397,26 +442,38 @@ object Matchless { // this shouldn't be possible, since there are no total list matches with // one item since we recurse on a ListPat with the first item being Right // which as we can see above always returns Some(_) - throw new IllegalStateException(s"$right should not be a total match") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"$right should not be a total match" + ) + // $COVERAGE-ON$ case (preLet, expr, binds) => - val letTail = anonList :: preLet val (resLet, leftOpt, resBind) = optAnonLeft match { case Some((anonLeft, ln)) => - val revList = App(reverseFn, NonEmptyList.one(anonLeft)) - (anonLeft :: letTail, Some(anonLeft), (ln, revList) :: binds) + val revList = + App(reverseFn, NonEmptyList.one(anonLeft)) + ( + anonLeft :: letTail, + Some(anonLeft), + (ln, revList) :: binds + ) case None => (letTail, None, binds) } - (resLet, SearchList(anonList, arg, expr, leftOpt), resBind) + ( + resLet, + SearchList(anonList, arg, expr, leftOpt), + resBind + ) } } } - case Left((glob, right@NonEmptyList(_: Pattern.ListPart.Glob, _))) => + case Left( + (glob, right @ NonEmptyList(_: Pattern.ListPart.Glob, _)) + ) => // we search on the right side, so the left will match nothing // this should be banned by SourceConverter/TotalityChecker because // it is confusing, but it can be handled @@ -435,7 +492,7 @@ object Matchless { } } } - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Pattern.Annotation(p, _) => @@ -445,29 +502,40 @@ object Matchless { // we assume the patterns have already been optimized // so that useless total patterns have been replaced with _ type Locals = Chain[(LocalAnonMut, Expr)] - def asStruct(getter: Int => CheapExpr): WriterT[F, Locals, UnionMatch] = { + def asStruct( + getter: Int => CheapExpr + ): WriterT[F, Locals, UnionMatch] = { // we have an and of a series of ors: // (m1 + m2 + m3) * (m4 + m5 + m6) ... = // we need to multiply them all out into a single set of ors - def operate(pat: Pattern[(PackageName, Constructor), Type], idx: Int): WriterT[F, Locals, UnionMatch] = + def operate( + pat: Pattern[(PackageName, Constructor), Type], + idx: Int + ): WriterT[F, Locals, UnionMatch] = maybeSimple(pat) match { case Some(Right(())) => // this is a total match WriterT.value(wildMatch) case Some(Left(v)) => // this is just an alias - WriterT.value(NonEmptyList((Nil, TrueConst, (v, getter(idx)) :: Nil), Nil)) + WriterT.value( + NonEmptyList((Nil, TrueConst, (v, getter(idx)) :: Nil), Nil) + ) case None => // we make an anonymous variable and write to that: for { nm <- WriterT.valueT[F, Locals, Long](makeAnon) lam = LocalAnonMut(nm) - um <- WriterT.valueT[F, Locals, UnionMatch](doesMatch(lam, pat, mustMatch)) + um <- WriterT.valueT[F, Locals, UnionMatch]( + doesMatch(lam, pat, mustMatch) + ) // if this is a total match, we don't need to do the getter at all - chain = if (um == wildMatch) Chain.empty else Chain.one((lam, getter(idx))) + chain = + if (um == wildMatch) Chain.empty + else Chain.one((lam, getter(idx))) _ <- WriterT.tell[F, Locals](chain) } yield um - } + } val ands: WriterT[F, Locals, List[UnionMatch]] = params.zipWithIndex @@ -475,20 +543,24 @@ object Matchless { ands.map(NonEmptyList.fromList(_) match { case None => wildMatch - case Some(nel) => product(nel) { case ((l1, o1, b1), (l2, o2, b2)) => - (l1 ::: l2, o1 && o2, b1 ::: b2) - } + case Some(nel) => + product(nel) { case ((l1, o1, b1), (l2, o2, b2)) => + (l1 ::: l2, o1 && o2, b1 ::: b2) + } }) } def forStruct(size: Int) = - asStruct { pos => GetStructElement(arg, pos, size) } - .run + asStruct { pos => GetStructElement(arg, pos, size) }.run .map { case (anons, ums) => ums.map { case (pre, cond, bind) => - val pre1 = anons.foldLeft(pre) { case (pre, (a, _)) => a :: pre } + val pre1 = anons.foldLeft(pre) { case (pre, (a, _)) => + a :: pre + } // we have to set these variables before we can evaluate the condition - val cond1 = anons.foldLeft(cond) { case (c, (a, e)) => SetMut(a, e) && c } + val cond1 = anons.foldLeft(cond) { case (c, (a, e)) => + SetMut(a, e) && c + } (pre1, cond1, bind) } } @@ -496,73 +568,97 @@ object Matchless { variantOf(pack, cname) match { case Some(dr) => dr match { - case DataRepr.Struct(size) => forStruct(size) - case DataRepr.NewType => forStruct(1) + case DataRepr.Struct(size) => forStruct(size) + case DataRepr.NewType => forStruct(1) case DataRepr.Enum(vidx, size, f) => // if we match the variant, then treat it as a struct - val cv: BoolExpr = if (mustMatch) TrueConst else CheckVariant(arg, vidx, size, f) - asStruct { pos => GetEnumElement(arg, vidx, pos, size) } - .run + val cv: BoolExpr = + if (mustMatch) TrueConst + else CheckVariant(arg, vidx, size, f) + asStruct { pos => GetEnumElement(arg, vidx, pos, size) }.run .map { case (anons, ums) => if (ums == wildMatch) { // we just need to check the variant - assert(anons.isEmpty, "anons must by construction always be empty on wildMatch") + assert( + anons.isEmpty, + "anons must by construction always be empty on wildMatch" + ) NonEmptyList((Nil, cv, Nil), Nil) - } - else { + } else { // now we need to set up the binds if the variant is right - val cond1 = anons.foldLeft(cv) { case (c, (mut, expr)) => - c && SetMut(mut, expr) + val cond1 = anons.foldLeft(cv) { + case (c, (mut, expr)) => + c && SetMut(mut, expr) } ums.map { case (pre, cond, b) => - val pre1 = anons.foldLeft(pre) { case (pre, (mut, _)) => mut :: pre } + val pre1 = anons.foldLeft(pre) { + case (pre, (mut, _)) => mut :: pre + } (pre1, cond1 && cond, b) } } } case DataRepr.ZeroNat => - val cv: BoolExpr = if (mustMatch) TrueConst else EqualsNat(arg, DataRepr.ZeroNat) + val cv: BoolExpr = + if (mustMatch) TrueConst + else EqualsNat(arg, DataRepr.ZeroNat) Monad[F].pure(NonEmptyList((Nil, cv, Nil), Nil)) case DataRepr.SuccNat => params match { case single :: Nil => // if we match, we recur on the inner pattern and prev of current - val check = if (mustMatch) TrueConst else EqualsNat(arg, DataRepr.SuccNat) + val check = + if (mustMatch) TrueConst + else EqualsNat(arg, DataRepr.SuccNat) for { nm <- makeAnon loc = LocalAnonMut(nm) prev = PrevNat(arg) rest <- doesMatch(loc, single, mustMatch) - } yield rest.map { case (preLets, cond, res) => (loc ::preLets, check && SetMut(loc, prev) && cond, res) } + } yield rest.map { case (preLets, cond, res) => + ( + loc :: preLets, + check && SetMut(loc, prev) && cond, + res + ) + } case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected typechecked Nat to only have one param, found: $other in $pat") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected typechecked Nat to only have one param, found: $other in $pat" + ) + // $COVERAGE-ON$ } } case None => // $COVERAGE-OFF$ - throw new IllegalStateException(s"could not find $cons in global data types") - // $COVERAGE-ON$ - } + throw new IllegalStateException( + s"could not find $cons in global data types" + ) + // $COVERAGE-ON$ + } case Pattern.Union(h, ts) => // note this list is exactly as long as h :: ts - val unionMustMatch = NonEmptyList.fromListUnsafe(List.fill(ts.size)(false) ::: mustMatch :: Nil) - ((h :: ts).zip(unionMustMatch)).traverse { case (p, mm) => doesMatch(arg, p, mm) }.map { nene => - val nel = nene.flatten - // at the first total match, we can stop - stopAt(nel) { - case (_, TrueConst, _) => true - case _ => false + val unionMustMatch = NonEmptyList.fromListUnsafe( + List.fill(ts.size)(false) ::: mustMatch :: Nil + ) + ((h :: ts) + .zip(unionMustMatch)) + .traverse { case (p, mm) => doesMatch(arg, p, mm) } + .map { nene => + val nel = nene.flatten + // at the first total match, we can stop + stopAt(nel) { + case (_, TrueConst, _) => true + case _ => false + } } - } } } def lets(binds: List[(Bindable, Expr)], in: Expr): Expr = binds.foldRight(in) { case ((b, e), r) => - val arg = Right((b, RecursionKind.NonRecursive)) Let(arg, e, r) } @@ -572,12 +668,27 @@ object Matchless { LetMut(anon, rest) } - def matchExpr(arg: Expr, tmp: F[Long], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr)]): F[Expr] = { - - def recur(arg: CheapExpr, branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr)]): F[Expr] = { + def matchExpr( + arg: Expr, + tmp: F[Long], + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], Expr) + ] + ): F[Expr] = { + + def recur( + arg: CheapExpr, + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], Expr) + ] + ): F[Expr] = { val (p1, r1) = branches.head - def loop(cbs: NonEmptyList[(List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)])]): F[Expr] = + def loop( + cbs: NonEmptyList[ + (List[LocalAnonMut], BoolExpr, List[(Bindable, Expr)]) + ] + ): F[Expr] = cbs match { case NonEmptyList((b0, TrueConst, binds), _) => // this is a total match, no fall through @@ -620,7 +731,10 @@ object Matchless { // toy matcher to see the structure // Left means match any number of items, like *_ - def matchList[A, B: Monoid](items: List[A], pattern: List[Either[List[A] => B, A => Option[B]]]): Option[B] = + def matchList[A, B: Monoid]( + items: List[A], + pattern: List[Either[List[A] => B, A => Option[B]]] + ): Option[B] = pattern match { case Nil => if (items.isEmpty) Some(Monoid[B].empty) @@ -629,7 +743,7 @@ object Matchless { items match { case ih :: it => fn(ih) match { - case None => None + case None => None case Some(b) => matchList(it, pt).map(Monoid[B].combine(b, _)) } case Nil => None @@ -638,13 +752,13 @@ object Matchless { case Left(lstFn) :: Nil => Some(lstFn(items)) - case Left(lstFn) :: (pt@(Left(_) :: _)) => + case Left(lstFn) :: (pt @ (Left(_) :: _)) => // it is ambiguous how much to absorb // so, just assume lstFn gets nothing matchList(items, pt) .map(Monoid.combine(lstFn(Nil), _)) - case Left(lstFn) :: (pt@(Right(_) :: _))=> + case Left(lstFn) :: (pt @ (Right(_) :: _)) => var revLeft: List[A] = Nil var it = items var result: Option[B] = None @@ -660,11 +774,11 @@ object Matchless { } } result - /* - * The above should be an imperative version - * of this code. The imperative code - * is easier to translate into low level - * instructions + /* + * The above should be an imperative version + * of this code. The imperative code + * is easier to translate into low level + * instructions items .toStream .mapWithIndex { (a, idx) => afn(a).map((_, idx)) } @@ -680,13 +794,14 @@ object Matchless { } } .headOption - */ + */ } - /** - * return the expanded product of sums - */ - def product[A1](sum: NonEmptyList[NonEmptyList[A1]])(prod: (A1, A1) => A1): NonEmptyList[A1] = + /** return the expanded product of sums + */ + def product[A1]( + sum: NonEmptyList[NonEmptyList[A1]] + )(prod: (A1, A1) => A1): NonEmptyList[A1] = sum match { case NonEmptyList(h, Nil) => // this (a1 + a2 + a3) case diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala index b449f6632..dd269e719 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessFromTypedExpr.scala @@ -6,34 +6,35 @@ import cats.implicits._ object MatchlessFromTypedExpr { // compile a set of packages given a set of external remappings - def compile[A](pm: PackageMap.Typed[A])(implicit ec: Par.EC): Map[PackageName, List[(Bindable, Matchless.Expr)]] = { + def compile[A]( + pm: PackageMap.Typed[A] + )(implicit ec: Par.EC): Map[PackageName, List[(Bindable, Matchless.Expr)]] = { val gdr = pm.getDataRepr // on JS Par.F[A] is actually Id[A], so we need to hold hands a bit - val allItemsList = pm.toMap - .toList - .traverse[Par.F, (PackageName, List[(Bindable, Matchless.Expr)])] { case (pname, pack) => - val lets = pack.program.lets - - Par.start { - val exprs: List[(Bindable, Matchless.Expr)] = - rankn.RefSpace - .allocCounter - .flatMap { c => - lets - .traverse { - case (name, rec, te) => - Matchless.fromLet(name, rec, te, gdr, c) - .map((name, _)) - } - } - .run - .value - - (pname, exprs) - } + val allItemsList = pm.toMap.toList + .traverse[Par.F, (PackageName, List[(Bindable, Matchless.Expr)])] { + case (pname, pack) => + val lets = pack.program.lets + + Par.start { + val exprs: List[(Bindable, Matchless.Expr)] = + rankn.RefSpace.allocCounter + .flatMap { c => + lets + .traverse { case (name, rec, te) => + Matchless + .fromLet(name, rec, te, gdr, c) + .map((name, _)) + } + } + .run + .value + + (pname, exprs) + } } // JS needs this to not see through the Par.F as Id diff --git a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala index 93f099b44..d95155928 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MatchlessToValue.scala @@ -16,7 +16,9 @@ object MatchlessToValue { import Matchless._ // reuse some cache structures across a number of calls - def traverse[F[_]: Functor](me: F[Expr])(resolve: (PackageName, Identifier) => Eval[Value]): F[Eval[Value]] = { + def traverse[F[_]: Functor]( + me: F[Expr] + )(resolve: (PackageName, Identifier) => Eval[Value]): F[Eval[Value]] = { val env = new Impl.Env(resolve) val fns = Functor[F].map(me) { expr => env.loop(expr) @@ -42,9 +44,10 @@ object MatchlessToValue { case MakeEnum(variant, arity, _) => if (arity == 0) SumValue(variant, UnitValue) else if (arity == 1) { - FnValue { case NonEmptyList(v, _) => SumValue(variant, ConsValue(v, UnitValue)) } - } - else + FnValue { case NonEmptyList(v, _) => + SumValue(variant, ConsValue(v, UnitValue)) + } + } else // arity > 1 FnValue { args => val prod = ProductValue.fromList(args.toList) @@ -53,9 +56,10 @@ object MatchlessToValue { case MakeStruct(arity) => if (arity == 0) UnitValue else if (arity == 1) FnValue.identity - else FnValue { args => - ProductValue.fromList(args.toList) - } + else + FnValue { args => + ProductValue.fromList(args.toList) + } case ZeroNat => zeroNat case SuccNat => succNat } @@ -65,15 +69,18 @@ object MatchlessToValue { val uninit: Value = ExternalValue(Uninitialized) final case class Scope( - locals: Map[Bindable, Eval[Value]], - anon: LongMap[Value], - muts: MLongMap[Value]) { + locals: Map[Bindable, Eval[Value]], + anon: LongMap[Value], + muts: MLongMap[Value] + ) { def let(b: Bindable, v: Eval[Value]): Scope = copy(locals = locals.updated(b, v)) def letAll[F[_]: Foldable](bs: F[(Bindable, Value)]): Scope = - copy(locals = bs.foldLeft(locals) { case (locals, (b, v)) => locals.updated(b, Eval.now(v)) }) + copy(locals = bs.foldLeft(locals) { case (locals, (b, v)) => + locals.updated(b, Eval.now(v)) + }) def updateMut(mutIdx: Long, v: Value): Unit = { assert(muts.contains(mutIdx)) @@ -85,13 +92,17 @@ object MatchlessToValue { def loc(b: Bindable): Eval[Value] = locals.get(b) match { case Some(v) => v - case None => sys.error(s"couldn't find: $b in ${locals.keys.map(_.asString).toList} capturing: ${it.toList}") + case None => + sys.error( + s"couldn't find: $b in ${locals.keys.map(_.asString).toList} capturing: ${it.toList}" + ) } Scope( it.iterator.map { b => (b, loc(b)) }.toMap, - LongMap.empty, - MLongMap()) + LongMap.empty, + MLongMap() + ) } } @@ -102,7 +113,9 @@ object MatchlessToValue { sealed abstract class Scoped[A] { def apply(s: Scope): A def map[B](fn: A => B): Scoped[B] - def and(that: Scoped[Boolean])(implicit ev: Is[A, Boolean]): Scoped[Boolean] = { + def and( + that: Scoped[Boolean] + )(implicit ev: Is[A, Boolean]): Scoped[Boolean] = { // boolean conditions are generally never static, so we can't easily exercise // this code if we specialize it. So, we assume it is dynamic here val thisBool = ev.substitute[Scoped](this) @@ -135,7 +148,9 @@ object MatchlessToValue { def pure[A](a: A): Scoped[A] = Static(a) override def map[A, B](aa: Scoped[A])(fn: A => B): Scoped[B] = aa.map(fn) - override def map2[A, B, C](aa: Scoped[A], ab: Scoped[B])(fn: (A, B) => C): Scoped[C] = + override def map2[A, B, C](aa: Scoped[A], ab: Scoped[B])( + fn: (A, B) => C + ): Scoped[C] = (aa, ab) match { case (Static(a), Static(b)) => Static(fn(a, b)) case (Static(a), db) => @@ -157,7 +172,6 @@ object MatchlessToValue { private def boolExpr(ix: BoolExpr): Scoped[Boolean] = ix match { case EqualsLit(expr, lit) => - val litAny = lit.unboxToAny loop(expr).map { e => @@ -200,7 +214,9 @@ object MatchlessToValue { matchString(arg, pat, 0) != null } case _ => - val bary = binds.iterator.collect { case LocalAnonMut(id) => id }.toArray + val bary = binds.iterator.collect { case LocalAnonMut(id) => + id + }.toArray // this may be static val matchScope = loop(str).map { str => @@ -217,8 +233,7 @@ object MatchlessToValue { idx = idx + 1 } true - } - else false + } else false } } @@ -238,19 +253,24 @@ object MatchlessToValue { var res = false while (currentList ne null) { currentList match { - case nonempty@VList.Cons(_, tail) => + case nonempty @ VList.Cons(_, tail) => scope.updateMut(mutV, nonempty) res = checkF(scope) if (res) { currentList = null } else { currentList = tail } case _ => currentList = null - // we don't match empty lists + // we don't match empty lists } } res } - case SearchList(LocalAnonMut(mutV), init, check, Some(LocalAnonMut(left))) => + case SearchList( + LocalAnonMut(mutV), + init, + check, + Some(LocalAnonMut(left)) + ) => val initF = loop(init) val checkF = boolExpr(check) @@ -261,7 +281,7 @@ object MatchlessToValue { var leftList = VList.VNil while (currentList ne null) { currentList match { - case nonempty@VList.Cons(head, tail) => + case nonempty @ VList.Cons(head, tail) => scope.updateMut(mutV, nonempty) scope.updateMut(left, leftList) res = checkF(scope) @@ -272,14 +292,19 @@ object MatchlessToValue { } case _ => currentList = null - // we don't match empty lists + // we don't match empty lists } } res } } - def buildLoop(caps: List[Bindable], fnName: Bindable, args: NonEmptyList[Bindable], body: Scoped[Value]): Scoped[Value] = { + def buildLoop( + caps: List[Bindable], + fnName: Bindable, + args: NonEmptyList[Bindable], + body: Scoped[Value] + ): Scoped[Value] = { val argCount = args.length val argNames: Array[Bindable] = args.toList.toArray if ((caps.lengthCompare(1) == 0) && (caps.head == fnName)) { @@ -319,8 +344,7 @@ object MatchlessToValue { } Static(fn) - } - else { + } else { Dynamic { scope => // TODO this maybe isn't helpful // it doesn't matter if the scope @@ -381,8 +405,7 @@ object MatchlessToValue { resFn(scope2) } Static(fn) - } - else { + } else { Dynamic { scope => val scope1 = scope.capture(caps) // hopefully optimization/normalization has lifted anything @@ -403,8 +426,8 @@ object MatchlessToValue { // this has to be lazy because it could be // in this package, which isn't complete yet Dynamic { (_: Scope) => res.value } - case Local(b) => Dynamic(_.locals(b).value) - case LocalAnon(a) => Dynamic(_.anon(a)) + case Local(b) => Dynamic(_.locals(b).value) + case LocalAnon(a) => Dynamic(_.anon(a)) case LocalAnonMut(m) => Dynamic(_.muts(m)) case App(expr, args) => // TODO: App(LoopFn(.. @@ -418,7 +441,8 @@ object MatchlessToValue { Applicative[Scoped].map2(exprFn, argsFn) { (fn, args) => fn.applyAll(args) } - case Let(Right((n1, r)), loopFn@LoopFn(_, n2, _, _), Local(n3)) if (n1 === n3) && (n1 === n2) && r.isRecursive => + case Let(Right((n1, r)), loopFn @ LoopFn(_, n2, _, _), Local(n3)) + if (n1 === n3) && (n1 === n2) && r.isRecursive => // LoopFn already correctly handles recursion loop(loopFn) case Let(localOrBind, value, in) => @@ -440,8 +464,7 @@ object MatchlessToValue { scope1 } - } - else { + } else { inF.withScope { (scope: Scope) => val vv = Eval.now(valueF(scope)) scope.let(b, vv) @@ -455,7 +478,7 @@ object MatchlessToValue { } case LetMut(LocalAnonMut(l), in) => loop(in) match { - case s@Static(_) => s + case s @ Static(_) => s case Dynamic(inF) => Dynamic { (scope: Scope) => // we make sure there is @@ -508,8 +531,7 @@ object MatchlessToValue { if (sz == 1) { // this is a newtype loopFn - } - else { + } else { loop(expr).map { p => p.asProduct.get(idx) } @@ -530,17 +552,30 @@ object MatchlessToValue { private[this] val emptyStringArray: Array[String] = new Array[String](0) - def matchString(str: String, pat: List[Matchless.StrPart], binds: Int): Array[String] = { + def matchString( + str: String, + pat: List[Matchless.StrPart], + binds: Int + ): Array[String] = { import Matchless.StrPart._ - val results = if (binds > 0) new Array[String](binds) else emptyStringArray + val results = + if (binds > 0) new Array[String](binds) else emptyStringArray - def loop(offset: Int, pat: List[Matchless.StrPart], next: Int): Boolean = + def loop( + offset: Int, + pat: List[Matchless.StrPart], + next: Int + ): Boolean = pat match { case Nil => offset == str.length case LitStr(expect) :: tail => val len = expect.length - str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) + str.regionMatches(offset, expect, 0, len) && loop( + offset + len, + tail, + next + ) case (h: Glob) :: tail => tail match { case Nil => @@ -566,13 +601,11 @@ object MatchlessToValue { } result = true start = -1 - } - else { + } else { // we couldn't match here, try just after candidate start = candidate + 1 } - } - else { + } else { // no more candidates start = -1 } diff --git a/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala b/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala index c61942afe..2c3521436 100644 --- a/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala +++ b/core/src/main/scala/org/bykn/bosatsu/MemoryMain.scala @@ -7,27 +7,33 @@ import scala.collection.immutable.SortedMap import cats.implicits._ -class MemoryMain[F[_], K: Ordering](split: K => List[String])( - implicit val pathArg: Argument[K], - val innerMonad: MonadError[F, Throwable]) extends MainModule[Kleisli[F, MemoryMain.State[K], *]] { +class MemoryMain[F[_], K: Ordering](split: K => List[String])(implicit + val pathArg: Argument[K], + val innerMonad: MonadError[F, Throwable] +) extends MainModule[Kleisli[F, MemoryMain.State[K], *]] { type IO[A] = Kleisli[F, MemoryMain.State[K], A] type Path = K def readPath(p: Path): IO[String] = - Kleisli.ask[F, MemoryMain.State[K]] + Kleisli + .ask[F, MemoryMain.State[K]] .flatMap { files => files.get(p) match { case Some(MemoryMain.FileContent.Str(res)) => moduleIOMonad.pure(res) - case other => moduleIOMonad.raiseError(new Exception(s"expect String content, found: $other")) + case other => + moduleIOMonad.raiseError( + new Exception(s"expect String content, found: $other") + ) } } def resolvePath: Option[(Path, PackageName) => IO[Option[Path]]] = None def readPackages(paths: List[Path]): IO[List[Package.Typed[Unit]]] = - Kleisli.ask[F, MemoryMain.State[K]] + Kleisli + .ask[F, MemoryMain.State[K]] .flatMap { files => paths .traverse { path => @@ -36,14 +42,16 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( moduleIOMonad.pure(res) case other => moduleIOMonad.raiseError[List[Package.Typed[Unit]]]( - new Exception(s"expect Packages content, found: $other")) + new Exception(s"expect Packages content, found: $other") + ) } } .map(_.flatten) } def readInterfaces(paths: List[Path]): IO[List[Package.Interface]] = - Kleisli.ask[F, MemoryMain.State[K]] + Kleisli + .ask[F, MemoryMain.State[K]] .flatMap { files => paths .traverse { path => @@ -52,7 +60,8 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( moduleIOMonad.pure(res) case other => moduleIOMonad.raiseError[List[Package.Interface]]( - new Exception(s"expect Packages content, found: $other")) + new Exception(s"expect Packages content, found: $other") + ) } } .map(_.flatten) @@ -66,22 +75,27 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( def runWith( files: Iterable[(K, String)], packages: Iterable[(K, List[Package.Typed[Unit]])] = Nil, - interfaces: Iterable[(K, List[Package.Interface])] = Nil)(cmd: List[String]): F[Output] = - run(cmd) match { - case Left(_) => - innerMonad.raiseError[Output](new Exception(s"got the help message for: $cmd")) - case Right(io) => - val state0 = files.foldLeft(SortedMap.empty[K, MemoryMain.FileContent]) { case (st, (k, str)) => + interfaces: Iterable[(K, List[Package.Interface])] = Nil + )(cmd: List[String]): F[Output] = + run(cmd) match { + case Left(_) => + innerMonad.raiseError[Output]( + new Exception(s"got the help message for: $cmd") + ) + case Right(io) => + val state0 = + files.foldLeft(SortedMap.empty[K, MemoryMain.FileContent]) { + case (st, (k, str)) => st.updated(k, MemoryMain.FileContent.Str(str)) - } - val state1 = packages.foldLeft(state0) { case (st, (k, packs)) => - st.updated(k, MemoryMain.FileContent.Packages(packs)) - } - val state2 = interfaces.foldLeft(state1) { case (st, (k, ifs)) => - st.updated(k, MemoryMain.FileContent.Interfaces(ifs)) - } - io.run(state2) + } + val state1 = packages.foldLeft(state0) { case (st, (k, packs)) => + st.updated(k, MemoryMain.FileContent.Packages(packs)) } + val state2 = interfaces.foldLeft(state1) { case (st, (k, ifs)) => + st.updated(k, MemoryMain.FileContent.Interfaces(ifs)) + } + io.run(state2) + } def pathPackage(roots: List[Path], packFile: Path): Option[PackageName] = { val fparts = split(packFile) @@ -91,8 +105,7 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( if (fparts.startsWith(splitP)) { val parts = fparts.drop(splitP.length) PackageName.parse(parts.mkString("/")) - } - else None + } else None } @annotation.tailrec @@ -113,7 +126,6 @@ class MemoryMain[F[_], K: Ordering](split: K => List[String])( Kleisli(_ => innerMonad.pure(a)) } - object MemoryMain { sealed abstract class FileContent object FileContent { diff --git a/core/src/main/scala/org/bykn/bosatsu/NameKind.scala b/core/src/main/scala/org/bykn/bosatsu/NameKind.scala index f0b9d06c9..9932cdc9c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/NameKind.scala +++ b/core/src/main/scala/org/bykn/bosatsu/NameKind.scala @@ -4,14 +4,24 @@ import Identifier.Bindable sealed abstract class NameKind[T] object NameKind { - case class Let[T](name: Bindable, recursive: RecursionKind, value: TypedExpr[T]) extends NameKind[T] + case class Let[T]( + name: Bindable, + recursive: RecursionKind, + value: TypedExpr[T] + ) extends NameKind[T] case class Constructor[T]( - cn: Identifier.Constructor, - params: List[(Bindable, rankn.Type)], - defined: rankn.DefinedType[Kind.Arg], - valueType: rankn.Type) extends NameKind[T] - case class Import[T](fromPack: Package.Interface, originalName: Identifier) extends NameKind[T] - case class ExternalDef[T](pack: PackageName, defName: Identifier, defType: rankn.Type) extends NameKind[T] + cn: Identifier.Constructor, + params: List[(Bindable, rankn.Type)], + defined: rankn.DefinedType[Kind.Arg], + valueType: rankn.Type + ) extends NameKind[T] + case class Import[T](fromPack: Package.Interface, originalName: Identifier) + extends NameKind[T] + case class ExternalDef[T]( + pack: PackageName, + defName: Identifier, + defType: rankn.Type + ) extends NameKind[T] def externals[T](from: Package.Typed[T]): Iterable[ExternalDef[T]] = { val prog = from.program @@ -24,7 +34,10 @@ object NameKind { } } - def apply[T](from: Package.Typed[T], item: Identifier): Option[NameKind[T]] = { + def apply[T]( + from: Package.Typed[T], + item: Identifier + ): Option[NameKind[T]] = { val prog = from.program def getLet: Option[NameKind[T]] = diff --git a/core/src/main/scala/org/bykn/bosatsu/Operators.scala b/core/src/main/scala/org/bykn/bosatsu/Operators.scala index 17b327970..971d15b4f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Operators.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Operators.scala @@ -11,7 +11,7 @@ object Operators { val leftDone = left.length <= idx val rightDone = right.length <= idx (leftDone, rightDone) match { - case (true, true) => 0 + case (true, true) => 0 case (true, false) => 1 case (false, true) => -1 case (false, false) => @@ -21,7 +21,8 @@ object Operators { else { Integer.compare( priorityMap.getOrElse(lc, Int.MaxValue), - priorityMap.getOrElse(rc, Int.MaxValue)) + priorityMap.getOrElse(rc, Int.MaxValue) + ) } } } @@ -30,26 +31,19 @@ object Operators { else loop(0) } - /** - * strings for operators allowed in single character - * operators (excludes = and .) - */ + /** strings for operators allowed in single character operators (excludes = + * and .) + */ val singleToks = - List( - "/", "%", "*", - "-", "+", - "<", ">", - "!", "$", - "&", "^", "|", - "?", "~").map(_.intern) + List("/", "%", "*", "-", "+", "<", ">", "!", "$", "&", "^", "|", "?", "~") + .map(_.intern) private def from(strs: Iterable[String]): P[Unit] = P.stringIn(strs).void - /** - * strings for operators allowed in single character - * operators includes singleToks and . and = - */ + /** strings for operators allowed in single character operators includes + * singleToks and . and = + */ val multiToks: List[String] = ".".intern :: singleToks ::: List("=".intern) @@ -57,22 +51,17 @@ object Operators { from(multiToks) private val priorityMap: Map[String, Int] = - multiToks - .iterator - .zipWithIndex - .toMap - - /** - * Here are a list of operators we allow - */ + multiToks.iterator.zipWithIndex.toMap + + /** Here are a list of operators we allow + */ val operatorToken: P[String] = { val singles = from(singleToks) // write this in a way to avoid backtracking (((P.string("<-") | P.char('=') | P.string("->")) ~ multiToksP.rep).void | (singles ~ multiToksP.rep0).void | - multiToksP.rep(min = 2).void) - .string + multiToksP.rep(min = 2).void).string .map(_.intern) } @@ -87,17 +76,19 @@ object Operators { object Formula { case class Sym[A](value: A) extends Formula[A] - case class Op[A](left: Formula[A], op: String, right: Formula[A]) extends Formula[A] - - /** - * 1 * 2 + 3 => (1 * 2) + 3 - * 1 * 2 * 3 => ((1 * 2) * 3) - */ - def toFormula[A](init: Formula[A], rest: List[(String, Formula[A])]): Formula[A] = + case class Op[A](left: Formula[A], op: String, right: Formula[A]) + extends Formula[A] + + /** 1 * 2 + 3 => (1 * 2) + 3 1 * 2 * 3 => ((1 * 2) * 3) + */ + def toFormula[A]( + init: Formula[A], + rest: List[(String, Formula[A])] + ): Formula[A] = rest match { - case Nil => init + case Nil => init case (op, next) :: Nil => Op(init, op, next) - case (op1, next1) :: (right@((op2, next2) :: tail)) => + case (op1, next1) :: (right @ ((op2, next2) :: tail)) => val c = compareOperator(op1, op2) if (c > 0) { // right binds tighter @@ -106,36 +97,35 @@ object Operators { // in this example, then starting again val f2 = Op(next1, op2, next2) toFormula(init, (op1, f2) :: tail) - } - else { + } else { // 1 + 2 + 3 => (1 + 2) + 3 // 1 * 2 + 3 => (1 * 2) + 3 toFormula(Op(init, op1, next1), right) } } - /** - * Parse a chain of at least 1 operator being applied - * with the operator precedence handled by the formula - */ + /** Parse a chain of at least 1 operator being applied with the operator + * precedence handled by the formula + */ def infixOps1[A](p: P[A]): P[A => Formula[A]] = { val opA = operatorToken ~ (Parser.maybeSpacesAndLines.with1 *> p) val chain: P[NonEmptyList[(String, A)]] = P.repSep(opA, min = 1, sep = Parser.maybeSpace) chain.map { rest => - - { (a: A) => toFormula(Sym(a), rest.toList.map { case (o, s) => (o, Sym(s)) }) } + { (a: A) => + toFormula(Sym(a), rest.toList.map { case (o, s) => (o, Sym(s)) }) + } } } - /** - * An a formula is a series of A's separated by spaces, with - * the correct parenthesis - */ + + /** An a formula is a series of A's separated by spaces, with the correct + * parenthesis + */ def parser[A](p: P[A]): P[Formula[A]] = (p ~ (Parser.maybeSpace.with1 *> infixOps1(p)).?) .map { - case (a, None) => Sym(a) + case (a, None) => Sym(a) case (a, Some(f)) => f(a) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala b/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala index 53e9af25f..7b1b67c61 100644 --- a/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala +++ b/core/src/main/scala/org/bykn/bosatsu/OptIndent.scala @@ -14,7 +14,7 @@ sealed abstract class OptIndent[A] { def sepDoc: Doc = this match { - case OptIndent.SameLine(_) => Doc.space + case OptIndent.SameLine(_) => Doc.space case OptIndent.NotSameLine(_) => Doc.empty } @@ -43,7 +43,8 @@ object OptIndent { NotSameLine(toPadIndent) case class SameLine[A](get: A) extends OptIndent[A] - case class NotSameLine[A](toPadIndent: Padding[Indented[A]]) extends OptIndent[A] { + case class NotSameLine[A](toPadIndent: Padding[Indented[A]]) + extends OptIndent[A] { def get: A = toPadIndent.padded.value } @@ -56,7 +57,7 @@ object OptIndent { val dpi = Document[Padding[Indented[A]]] Document.instance[OptIndent[A]] { - case SameLine(a) => da.document(a) + case SameLine(a) => da.document(a) case NotSameLine(p) => dpi.document(p) } } @@ -64,20 +65,23 @@ object OptIndent { def indy[A](p: Indy[A]): Indy[OptIndent[A]] = { val ind = Indented.indy(p) // we need to read at least 1 new line here - val not = ind.mapF { p => Padding.parser1(p).map(notSame[A](_)): P[OptIndent[A]] } + val not = ind.mapF { p => + Padding.parser1(p).map(notSame[A](_)): P[OptIndent[A]] + } val sm = p.map(same[A](_)) not <+> sm } - /** - * A: B or - * A: - * B - */ + /** A: B or A: B + */ def block[A, B](first: Indy[A], next: Indy[B]): Indy[(A, OptIndent[B])] = blockLike(first, next, (maybeSpace ~ P.char(':')).void) - def blockLike[A, B](first: Indy[A], next: Indy[B], sep: P0[Unit]): Indy[(A, OptIndent[B])] = + def blockLike[A, B]( + first: Indy[A], + next: Indy[B], + sep: P0[Unit] + ): Indy[(A, OptIndent[B])] = first .cutLeftP(sep ~ maybeSpace) .cutThen(OptIndent.indy(next)) diff --git a/core/src/main/scala/org/bykn/bosatsu/Package.scala b/core/src/main/scala/org/bykn/bosatsu/Package.scala index 0d26e47bc..6be4d9358 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Package.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Package.scala @@ -11,14 +11,16 @@ import rankn._ import Parser.{spaces, Combinators} import FixType.Fix -/** - * Represents a package over its life-cycle: from parsed to resolved to inferred - */ + +/** Represents a package over its life-cycle: from parsed to resolved to + * inferred + */ final case class Package[A, B, C, +D]( - name: PackageName, - imports: List[Import[A, B]], - exports: List[ExportedName[C]], - program: D) { + name: PackageName, + imports: List[Import[A, B]], + exports: List[ExportedName[C]], + program: D +) { // It is really important to cache the hashcode and these large dags if // we use them as hash keys @@ -49,27 +51,28 @@ final case class Package[A, B, C, +D]( def mapProgram[D1](fn: D => D1): Package[A, B, C, D1] = Package(name, imports, exports, fn(program)) - def replaceImports[A1, B1](newImports: List[Import[A1, B1]]): Package[A1, B1, C, D] = + def replaceImports[A1, B1]( + newImports: List[Import[A1, B1]] + ): Package[A1, B1, C, D] = Package(name, newImports, exports, program) } object Package { type Interface = Package[Nothing, Nothing, Referant[Kind.Arg], Unit] - /** - * This is a package whose import type is Either: - * 1 a package of the same kind - * 2 an interface - */ + + /** This is a package whose import type is Either: 1 a package of the same + * kind 2 an interface + */ type FixPackage[B, C, D] = Fix[λ[a => Either[Interface, Package[a, B, C, D]]]] - type PackageF[A, B, C] = Either[Interface, Package[FixPackage[A, B, C], A, B, C]] + type PackageF[A, B, C] = + Either[Interface, Package[FixPackage[A, B, C], A, B, C]] type PackageF2[A, B] = PackageF[A, A, B] type Parsed = Package[PackageName, Unit, Unit, List[Statement]] - type Resolved = FixPackage[Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] - type Typed[T] = Package[ - Interface, - NonEmptyList[Referant[Kind.Arg]], - Referant[Kind.Arg], - Program[TypeEnv[Kind.Arg], TypedExpr[T], Any]] + type Resolved = + FixPackage[Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] + type Typed[T] = Package[Interface, NonEmptyList[Referant[Kind.Arg]], Referant[ + Kind.Arg + ], Program[TypeEnv[Kind.Arg], TypedExpr[T], Any]] type Inferred = Typed[Declaration] val typedFunctor: Functor[Typed] = @@ -82,37 +85,39 @@ object Package { } } - /** - * Return the last binding in the file with the test type - */ - def testValue[A](tp: Typed[A]): Option[(Identifier.Bindable, RecursionKind, TypedExpr[A])] = - tp - .program - .lets - .filter { case (_, _, te) => te.getType == Type.TestType } - .lastOption - - /** - * Discard any top level values that are not referenced, exported, - * the final test value, or the final expression - * - * This is used to remove private top levels that were inlined. - */ + /** Return the last binding in the file with the test type + */ + def testValue[A]( + tp: Typed[A] + ): Option[(Identifier.Bindable, RecursionKind, TypedExpr[A])] = + tp.program.lets.filter { case (_, _, te) => + te.getType == Type.TestType + }.lastOption + + /** Discard any top level values that are not referenced, exported, the final + * test value, or the final expression + * + * This is used to remove private top levels that were inlined. + */ def discardUnused[A](tp: Typed[A]): Typed[A] = { val pinned: Set[Identifier] = tp.exports.iterator.map(_.name).toSet ++ - tp.program.lets.lastOption.map(_._1) ++ + tp.program.lets.lastOption.map(_._1) ++ testValue(tp).map(_._1) def topLevels(s: Set[(PackageName, Identifier)]): Set[Identifier] = s.collect { case (p, i) if p === tp.name => i } - val letWithGlobals = tp.program.lets.map { case tup @ (_, _, te) => (tup, topLevels(te.globals)) } + val letWithGlobals = tp.program.lets.map { case tup @ (_, _, te) => + (tup, topLevels(te.globals)) + } @annotation.tailrec def loop(reached: Set[Identifier]): Set[Identifier] = { val step = letWithGlobals - .foldMap { case ((bn, _, _), tops) => if (reached(bn)) tops else Set.empty[Identifier] } + .foldMap { case ((bn, _, _), tops) => + if (reached(bn)) tops else Set.empty[Identifier] + } if (step.forall(reached)) reached else loop(step | reached) @@ -120,7 +125,9 @@ object Package { val reached = loop(pinned) - val reachedLets = letWithGlobals.collect { case (tup @ (bn, _, _), _) if reached(bn) => tup } + val reachedLets = letWithGlobals.collect { + case (tup @ (bn, _, _), _) if reached(bn) => tup + } tp.copy(program = tp.program.copy(lets = reachedLets)) } @@ -129,10 +136,10 @@ object Package { def unfix[A, B, C](fp: FixPackage[A, B, C]): PackageF[A, B, C] = FixType.unfix[λ[a => Either[Interface, Package[a, A, B, C]]]](fp) - /** - * build a Parsed Package from a Statement. This is useful for testing or - * library usages. - */ + + /** build a Parsed Package from a Statement. This is useful for testing or + * library usages. + */ def fromStatements(pn: PackageName, stmts: List[Statement]): Package.Parsed = Package(pn, Nil, Nil, stmts) @@ -142,98 +149,135 @@ object Package { def setProgramFrom[A, B](t: Typed[A], newFrom: B): Typed[A] = t.copy(program = t.program.copy(from = newFrom)) - implicit val document: Document[Package[PackageName, Unit, Unit, List[Statement]]] = - Document.instance[Package.Parsed] { case Package(name, imports, exports, statments) => - val p = Doc.text("package ") + Document[PackageName].document(name) + Doc.line - val i = imports match { - case Nil => Doc.empty - case nonEmptyImports => - Doc.line + - Doc.intercalate(Doc.line, nonEmptyImports.map(Document[Import[PackageName, Unit]].document _)) + - Doc.line - } - val e = exports match { - case Nil => Doc.empty - case nonEmptyExports => - Doc.line + - Doc.text("export ") + - Doc.intercalate(Doc.text(", "), nonEmptyExports.map(Document[ExportedName[Unit]].document _)) + - Doc.line - } - val b = statments.map(Document[Statement].document(_)) - Doc.intercalate(Doc.empty, p :: i :: e :: b) + implicit val document + : Document[Package[PackageName, Unit, Unit, List[Statement]]] = + Document.instance[Package.Parsed] { + case Package(name, imports, exports, statments) => + val p = + Doc.text("package ") + Document[PackageName].document(name) + Doc.line + val i = imports match { + case Nil => Doc.empty + case nonEmptyImports => + Doc.line + + Doc.intercalate( + Doc.line, + nonEmptyImports.map( + Document[Import[PackageName, Unit]].document _ + ) + ) + + Doc.line + } + val e = exports match { + case Nil => Doc.empty + case nonEmptyExports => + Doc.line + + Doc.text("export ") + + Doc.intercalate( + Doc.text(", "), + nonEmptyExports.map(Document[ExportedName[Unit]].document _) + ) + + Doc.line + } + val b = statments.map(Document[Statement].document(_)) + Doc.intercalate(Doc.empty, p :: i :: e :: b) } - def parser(defaultPack: Option[PackageName]): P0[Package[PackageName, Unit, Unit, List[Statement]]] = { + def parser( + defaultPack: Option[PackageName] + ): P0[Package[PackageName, Unit, Unit, List[Statement]]] = { // TODO: support comments before the Statement - val parsePack = Padding.parser((P.string("package").soft ~ spaces) *> PackageName.parser <* Parser.toEOL).map(_.padded) + val parsePack = Padding + .parser( + (P.string("package") + .soft ~ spaces) *> PackageName.parser <* Parser.toEOL + ) + .map(_.padded) val pname: P0[PackageName] = defaultPack match { - case None => parsePack + case None => parsePack case Some(p) => parsePack.?.map(_.getOrElse(p)) } val im = Padding.parser(Import.parser <* Parser.toEOL).map(_.padded).rep0 - val ex = Padding.parser((P.string("export").soft ~ spaces) *> ExportedName.parser.itemsMaybeParens.map(_._2) <* Parser.toEOL).map(_.padded) + val ex = Padding + .parser( + (P.string("export") + .soft ~ spaces) *> ExportedName.parser.itemsMaybeParens + .map(_._2) <* Parser.toEOL + ) + .map(_.padded) val body: P0[List[Statement]] = Statement.parser (pname, im, Parser.nonEmptyListToList(ex), body) .mapN { (p, i, e, b) => Package(p, i, e, b) } } - /** - * After having type checked the imports, we now type check the body - * in order to type check the exports - * - * This is used by test code - */ + /** After having type checked the imports, we now type check the body in order + * to type check the exports + * + * This is used by test code + */ def inferBody( - p: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], - stmts: List[Statement]): - Ior[NonEmptyList[PackageError], - Program[TypeEnv[Kind.Arg], TypedExpr[Declaration], List[Statement]]] = - inferBodyUnopt(p, imps, stmts).map { - case (fullTypeEnv, prog) => - TypedExprNormalization.normalizeProgram(p, fullTypeEnv, prog) - } + p: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], + stmts: List[Statement] + ): Ior[NonEmptyList[PackageError], Program[TypeEnv[Kind.Arg], TypedExpr[ + Declaration + ], List[Statement]]] = + inferBodyUnopt(p, imps, stmts).map { case (fullTypeEnv, prog) => + TypedExprNormalization.normalizeProgram(p, fullTypeEnv, prog) + } - /** - * Infer the types but do not optimize/normalize the lets - */ + /** Infer the types but do not optimize/normalize the lets + */ def inferBodyUnopt( - p: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], - stmts: List[Statement]): - Ior[NonEmptyList[PackageError], - (TypeEnv[Kind.Arg], Program[TypeEnv[Kind.Arg], TypedExpr[Declaration], List[Statement]])] = { + p: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]], + stmts: List[Statement] + ): Ior[NonEmptyList[ + PackageError + ], (TypeEnv[Kind.Arg], Program[TypeEnv[Kind.Arg], TypedExpr[Declaration], List[Statement]])] = { // here we make a pass to get all the local names - val optProg = SourceConverter.toProgram(p, imps.map { i => i.copy(pack = i.pack.name) }, stmts) - .leftMap(_.map(PackageError.SourceConverterErrorIn(_, p): PackageError).toNonEmptyList) + val optProg = SourceConverter + .toProgram(p, imps.map { i => i.copy(pack = i.pack.name) }, stmts) + .leftMap( + _.map( + PackageError.SourceConverterErrorIn(_, p): PackageError + ).toNonEmptyList + ) lazy val typeDefRegions: Map[Type.Const.Defined, Region] = - stmts.iterator.collect { - case tds: TypeDefinitionStatement => - Type.Const.Defined(p, TypeName(tds.name)) -> tds.region - } - .toMap + stmts.iterator.collect { case tds: TypeDefinitionStatement => + Type.Const.Defined(p, TypeName(tds.name)) -> tds.region + }.toMap optProg.flatMap { case Program((importedTypeEnv, parsedTypeEnv), lets, extDefs, _) => - val inferVarianceParsed: Ior[NonEmptyList[PackageError], ParsedTypeEnv[Kind.Arg]] = - KindFormula.solveShapesAndKinds(importedTypeEnv, parsedTypeEnv.allDefinedTypes.reverse) - .bimap({ necError => - necError.map(PackageError.KindInferenceError(p, _, typeDefRegions)).toNonEmptyList - }, { infDTs => - ParsedTypeEnv(infDTs, parsedTypeEnv.externalDefs) - }) + val inferVarianceParsed + : Ior[NonEmptyList[PackageError], ParsedTypeEnv[Kind.Arg]] = + KindFormula + .solveShapesAndKinds( + importedTypeEnv, + parsedTypeEnv.allDefinedTypes.reverse + ) + .bimap( + { necError => + necError + .map(PackageError.KindInferenceError(p, _, typeDefRegions)) + .toNonEmptyList + }, + { infDTs => + ParsedTypeEnv(infDTs, parsedTypeEnv.externalDefs) + } + ) inferVarianceParsed.flatMap { parsedTypeEnv => /* * Check that all recursion is allowable */ val defRecursionCheck: ValidatedNel[PackageError, Unit] = - stmts.traverse_(DefRecursionCheck.checkStatement(_)) + stmts + .traverse_(DefRecursionCheck.checkStatement(_)) .leftMap { badRecursions => badRecursions.map(PackageError.RecursionError(p, _)) } @@ -241,19 +285,21 @@ object Package { val typeEnv: TypeEnv[Kind.Arg] = TypeEnv.fromParsed(parsedTypeEnv) /* - * These are values, including all constructor functions - * that have been imported, this includes local external - * defs - */ + * These are values, including all constructor functions + * that have been imported, this includes local external + * defs + */ val withFQN: Map[(Option[PackageName], Identifier), Type] = { val fqn = - Referant.fullyQualifiedImportedValues(imps)(_.name) + Referant + .fullyQualifiedImportedValues(imps)(_.name) .iterator .map { case ((p, n), t) => ((Some(p), n), t) } // these are local construtors/externals val localDefined = - typeEnv.localValuesOf(p) + typeEnv + .localValuesOf(p) .iterator .map { case (n, t) => ((Some(p), n), t) } @@ -263,11 +309,17 @@ object Package { val fullTypeEnv = importedTypeEnv ++ typeEnv val totalityCheck = lets - .traverse { case (_, _, expr) => TotalityCheck(fullTypeEnv).checkExpr(expr) } - .leftMap { errs => errs.map(PackageError.TotalityCheckError(p, _)) } + .traverse { case (_, _, expr) => + TotalityCheck(fullTypeEnv).checkExpr(expr) + } + .leftMap { errs => + errs.map(PackageError.TotalityCheckError(p, _)) + } - val inferenceEither = Infer.typeCheckLets(p, lets) - .runFully(withFQN, + val inferenceEither = Infer + .typeCheckLets(p, lets) + .runFully( + withFQN, Referant.typeConstructors(imps) ++ typeEnv.typeConstructors, fullTypeEnv.toKindMap ) @@ -278,12 +330,15 @@ object Package { .map(PackageError.TypeErrorIn(_, p)) val checkUnusedLets = - lets.traverse_ { case (_, _, expr) => - UnusedLetCheck.check(expr) - } - .leftMap { errs => - NonEmptyList.one(PackageError.UnusedLetError(p, errs.toNonEmptyList)) - } + lets + .traverse_ { case (_, _, expr) => + UnusedLetCheck.check(expr) + } + .leftMap { errs => + NonEmptyList.one( + PackageError.UnusedLetError(p, errs.toNonEmptyList) + ) + } /* * Checks accumulate errors, but have no return value: @@ -291,11 +346,13 @@ object Package { * error accumulation */ val checks = List( - defRecursionCheck, checkUnusedLets, totalityCheck - ) - .sequence_ + defRecursionCheck, + checkUnusedLets, + totalityCheck + ).sequence_ - val inference = Validated.fromEither(inferenceEither).leftMap(NonEmptyList.of(_)) + val inference = + Validated.fromEither(inferenceEither).leftMap(NonEmptyList.of(_)) Parallel[Ior[NonEmptyList[PackageError], *]] .parProductR(checks.toIor)(inference.toIor) @@ -303,13 +360,15 @@ object Package { } } - def checkValuesHaveExportedTypes[V](pn: PackageName, exports: List[ExportedName[Referant[V]]]): List[PackageError] = { - val exportedTypes: List[DefinedType[V]] = exports - .iterator + def checkValuesHaveExportedTypes[V]( + pn: PackageName, + exports: List[ExportedName[Referant[V]]] + ): List[PackageError] = { + val exportedTypes: List[DefinedType[V]] = exports.iterator .map(_.tag) .collect { case Referant.Constructor(dt, _) => dt - case Referant.DefinedT(dt) => dt + case Referant.DefinedT(dt) => dt } .toList .distinct @@ -317,18 +376,16 @@ object Package { val exportedTE = TypeEnv.fromDefinitions(exportedTypes) type Exp = ExportedName[Referant[V]] - val usedTypes: Iterator[(Type.Const, Exp, Type)] = exports - .iterator + val usedTypes: Iterator[(Type.Const, Exp, Type)] = exports.iterator .flatMap { n => n.tag match { case Referant.Value(t) => Iterator.single((t, n)) - case _ => Iterator.empty + case _ => Iterator.empty } } .flatMap { case (t, n) => Type.constantsOf(t).map((_, n, t)) } .filter { case (Type.Const.Defined(p, _), _, _) => p === pn } - def errorFor(t: (Type.Const, Exp, Type)): List[PackageError] = exportedTE.toDefinedType(t._1) match { case None => @@ -339,29 +396,39 @@ object Package { usedTypes.flatMap(errorFor).toList } - /** - * The parsed representation of the predef. - */ + /** The parsed representation of the predef. + */ lazy val predefPackage: Package.Parsed = parser(None).parse(Predef.predefString) match { case Right((_, pack)) => // Make function defs: - def paramType(n: Int) = (TypeRef.TypeVar(s"i$n"), Some(Kind.Arg(Variance.contra, Kind.Type))) - def makeFns(n: Int, - typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])], - acc: List[Statement.ExternalStruct]): List[Statement.ExternalStruct] = + def paramType(n: Int) = + (TypeRef.TypeVar(s"i$n"), Some(Kind.Arg(Variance.contra, Kind.Type))) + def makeFns( + n: Int, + typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])], + acc: List[Statement.ExternalStruct] + ): List[Statement.ExternalStruct] = if (n > Type.FnType.MaxSize) acc else { - val fn = Statement.ExternalStruct(Identifier.Constructor(s"Fn$n"), typeArgs)(Region(0, 1)) + val fn = Statement.ExternalStruct( + Identifier.Constructor(s"Fn$n"), + typeArgs + )(Region(0, 1)) val acc1 = fn :: acc makeFns(n + 1, paramType(n) :: typeArgs, acc1) } val out = (TypeRef.TypeVar("z"), Some(Kind.Arg(Variance.co, Kind.Type))) val allFns = makeFns(1, paramType(0) :: out :: Nil, Nil).reverse - val exported = allFns.map { extstr => ExportedName.TypeName(extstr.name, ()) } + val exported = allFns.map { extstr => + ExportedName.TypeName(extstr.name, ()) + } // Add functions into the predef - pack.copy(exports = exported ::: pack.exports, program = allFns ::: pack.program) + pack.copy( + exports = exported ::: pack.exports, + program = allFns ::: pack.program + ) case Left(err) => val idx = err.failedAtOffset val lm = LocationMap(Predef.predefString) @@ -370,4 +437,4 @@ object Package { System.err.println(errorMsg) sys.error(errorMsg) } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageError.scala b/core/src/main/scala/org/bykn/bosatsu/PackageError.scala index 06abc9661..ab19585a4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageError.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageError.scala @@ -7,7 +7,10 @@ import rankn._ import LocationMap.Colorize sealed abstract class PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize): String + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ): String } object PackageError { @@ -15,22 +18,24 @@ object PackageError { // TODO: we should use the imports in each package to talk about // types in ways that are local to that package require(pack ne null) - tpes - .iterator - .map { t => - (t, Type.fullyResolvedDocument.document(t)) - } - .toMap + tpes.iterator.map { t => + (t, Type.fullyResolvedDocument.document(t)) + }.toMap } - def nearest[A](ident: Identifier, existing: Iterable[(Identifier, A)], count: Int): List[(Identifier, A)] = - existing - .iterator + def nearest[A]( + ident: Identifier, + existing: Iterable[(Identifier, A)], + count: Int + ): List[(Identifier, A)] = + existing.iterator .map { case (i, a) => val d = EditDistance.string(ident.asString, i.asString) (i, d, a) } - .filter(_._2 < ident.asString.length) // don't show things that require total edits + .filter( + _._2 < ident.asString.length + ) // don't show things that require total edits .toList .sortBy { case (_, d, _) => d } .distinct @@ -44,24 +49,28 @@ object PackageError { def headLine(packageName: PackageName, region: Option[Region]): Doc = { val (lm, sourceName) = getMapSrc(packageName) val suffix = (region.flatMap { r => lm.toLineCol(r.start) }) match { - case Some((line, col)) => s":${line + 1}:${col + 1}" - case None => "" + case Some((line, col)) => s":${line + 1}:${col + 1}" + case None => "" } Doc.text(s"in file: $sourceName$suffix, package ${packageName.asString}") } def getMapSrc(pack: PackageName): (LocationMap, String) = sm.get(pack) match { - case None => (emptyLocMap, "") + case None => (emptyLocMap, "") case Some(found) => found } } - - case class UnknownExport[A](ex: ExportedName[A], - in: PackageName, - lets: List[(Identifier.Bindable, RecursionKind, TypedExpr[Declaration])]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnknownExport[A]( + ex: ExportedName[A], + in: PackageName, + lets: List[(Identifier.Bindable, RecursionKind, TypedExpr[Declaration])] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, sourceName) = sourceMap.getMapSrc(in) val header = s"in $sourceName unknown export ${ex.name.sourceCodeRepr}" @@ -70,7 +79,10 @@ object PackageError { val candidates = nearest(ex.name, candidateMap, 3) .map { case (n, r) => - val pos = lm.toLineCol(r.start).map { case (l, c) => s":${l + 1}:${c + 1}" }.getOrElse("") + val pos = lm + .toLineCol(r.start) + .map { case (l, c) => s":${l + 1}:${c + 1}" } + .getOrElse("") s"${n.asString}$pos" } val candstr = candidates.mkString("\n\t", "\n\t", "\n") @@ -81,34 +93,51 @@ object PackageError { } } - case class PrivateTypeEscape[A](ex: ExportedName[A], - exType: Type, - in: PackageName, - privateType: Type.Const) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class PrivateTypeEscape[A]( + ex: ExportedName[A], + exType: Type, + in: PackageName, + privateType: Type.Const + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (_, sourceName) = sourceMap.getMapSrc(in) val pt = Type.TyConst(privateType) val tpeMap = showTypes(in, exType :: pt :: Nil) - val first = s"in $sourceName export ${ex.name.sourceCodeRepr} of type ${tpeMap(exType).render(80)}" + val first = + s"in $sourceName export ${ex.name.sourceCodeRepr} of type ${tpeMap(exType).render(80)}" if (exType == pt) { s"$first has an unexported (private) type." - } - else { + } else { s"$first references an unexported (private) type ${tpeMap(pt).render(80)}." } } } - case class UnknownImportPackage[A, B, C](pack: PackageName, fromName: PackageName) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnknownImportPackage[A, B, C]( + pack: PackageName, + fromName: PackageName + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (_, sourceName) = sourceMap.getMapSrc(fromName) s"in $sourceName package ${fromName.asString} imports unknown package ${pack.asString}" } } - case class DuplicatedImport(duplicates: NonEmptyList[(PackageName, ImportedName[Unit])]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = - duplicates.sortBy(_._2.localName) + case class DuplicatedImport( + duplicates: NonEmptyList[(PackageName, ImportedName[Unit])] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = + duplicates + .sortBy(_._2.localName) .toList .iterator .map { case (pack, imp) => @@ -120,55 +149,70 @@ object PackageError { // We could check if we forgot to export the name in the package and give that error case class UnknownImportName[A, B]( - in: PackageName, - importedPackage: PackageName, - letMap: Map[Identifier, Unit], - iname: ImportedName[A], - exports: List[ExportedName[B]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { - val ipname = importedPackage - - val (_, sourceName) = sourceMap.getMapSrc(in) - letMap - .get(iname.originalName) match { - case Some(_) => - s"in $sourceName package: ${ipname.asString} has ${iname.originalName.sourceCodeRepr} but it is not exported. Add to exports" - case None => - val near = nearest(iname.originalName, letMap, 3) - .map { case (n, _) => n.sourceCodeRepr } - .mkString(" Nearest: ", ", ", "") - s"in $sourceName package: ${ipname.asString} does not have name ${iname.originalName.sourceCodeRepr}.$near" - } + in: PackageName, + importedPackage: PackageName, + letMap: Map[Identifier, Unit], + iname: ImportedName[A], + exports: List[ExportedName[B]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { + val ipname = importedPackage + + val (_, sourceName) = sourceMap.getMapSrc(in) + letMap + .get(iname.originalName) match { + case Some(_) => + s"in $sourceName package: ${ipname.asString} has ${iname.originalName.sourceCodeRepr} but it is not exported. Add to exports" + case None => + val near = nearest(iname.originalName, letMap, 3) + .map { case (n, _) => n.sourceCodeRepr } + .mkString(" Nearest: ", ", ", "") + s"in $sourceName package: ${ipname.asString} does not have name ${iname.originalName.sourceCodeRepr}.$near" } } + } case class UnknownImportFromInterface[A, B]( - in: PackageName, - importingName: PackageName, - exportNames: List[Identifier], - iname: ImportedName[A], - exports: List[ExportedName[B]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { - - val exportMap = exportNames.map { e => (e, ()) }.toMap - - val near = Doc.text(" Nearest: ") + - (Doc.intercalate( + in: PackageName, + importingName: PackageName, + exportNames: List[Identifier], + iname: ImportedName[A], + exports: List[ExportedName[B]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { + + val exportMap = exportNames.map { e => (e, ()) }.toMap + + val near = Doc.text(" Nearest: ") + + (Doc + .intercalate( Doc.text(",") + Doc.line, nearest(iname.originalName, exportMap, 3) .map { ident => Doc.text(ident._1.sourceCodeRepr) } ) .nested(4) .grouped) - - (sourceMap.headLine(importingName, None) + Doc.hardLine + Doc.text( - s"does not have name ${iname.originalName}.") + near - ).render(80) - } + + (sourceMap.headLine(importingName, None) + Doc.hardLine + Doc.text( + s"does not have name ${iname.originalName}." + ) + near).render(80) } + } - case class CircularDependency[A, B, C](from: PackageName, path: NonEmptyList[PackageName]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class CircularDependency[A, B, C]( + from: PackageName, + path: NonEmptyList[PackageName] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val packs = from :: (path.toList) val msg = packs.map { p => val (_, src) = sourceMap.getMapSrc(p) @@ -179,35 +223,59 @@ object PackageError { } } - case class VarianceInferenceFailure(from: PackageName, failed: NonEmptyList[rankn.DefinedType[Unit]]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { - s"failed to infer variance in ${from.asString} of " + failed.toList.map(_.name.ident.asString).sorted.mkString(", ") + case class VarianceInferenceFailure( + from: PackageName, + failed: NonEmptyList[rankn.DefinedType[Unit]] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { + s"failed to infer variance in ${from.asString} of " + failed.toList + .map(_.name.ident.asString) + .sorted + .mkString(", ") } } - case class TypeErrorIn(tpeErr: Infer.Error, pack: PackageName) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class TypeErrorIn(tpeErr: Infer.Error, pack: PackageName) + extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val (teMessage, region) = tpeErr match { case Infer.Error.NotUnifiable(t0, t1, r0, r1) => val context0 = - if (r0 == r1) Doc.space // sometimes the region of the error is the same on right and left + if (r0 == r1) + Doc.space // sometimes the region of the error is the same on right and left else { - val m = lm.showRegion(r0, 2, errColor).getOrElse(Doc.str(r0)) // we should highlight the whole region + val m = lm + .showRegion(r0, 2, errColor) + .getOrElse(Doc.str(r0)) // we should highlight the whole region Doc.hardLine + m + Doc.hardLine } val context1 = - lm.showRegion(r1, 2, errColor).getOrElse(Doc.str(r1)) // we should highlight the whole region + lm.showRegion(r1, 2, errColor) + .getOrElse(Doc.str(r1)) // we should highlight the whole region val fnHint = (t0, t1) match { - case (Type.RootConst(Type.FnType(_, leftSize)), - Type.RootConst(Type.FnType(_, rightSize))) => + case ( + Type.RootConst(Type.FnType(_, leftSize)), + Type.RootConst(Type.FnType(_, rightSize)) + ) => // both are functions - def args(n: Int) = if (n == 1) "one argument" else s"$n arguments" - Doc.text(s"hint: the first type is a function with ${args(leftSize)} and the second is a function with ${args(rightSize)}.") + Doc.hardLine + def args(n: Int) = + if (n == 1) "one argument" else s"$n arguments" + Doc.text( + s"hint: the first type is a function with ${args(leftSize)} and the second is a function with ${args(rightSize)}." + ) + Doc.hardLine case (Type.Fun(_, _), _) | (_, Type.Fun(_, _)) => - Doc.text("hint: this often happens when you apply the wrong number of arguments to a function.") + Doc.hardLine + Doc.text( + "hint: this often happens when you apply the wrong number of arguments to a function." + ) + Doc.hardLine case _ => Doc.empty } @@ -219,26 +287,36 @@ object PackageError { (doc, Some(r0)) case Infer.Error.VarNotInScope((_, name), scope, region) => - val ctx = lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + val ctx = + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) val candidates: List[String] = nearest(name, scope.map { case ((_, n), _) => (n, ()) }, 3) .map { case (n, _) => n.asString } val cmessage = - if (candidates.nonEmpty) candidates.mkString("\nClosest: ", ", ", ".\n") + if (candidates.nonEmpty) + candidates.mkString("\nClosest: ", ", ", ".\n") else "" val qname = "\"" + name.sourceCodeRepr + "\"" - (Doc.text("name ") + Doc.text(qname) + Doc.text(" unknown.") + Doc.text(cmessage) + Doc.hardLine + - ctx, Some(region)) + ( + Doc.text("name ") + Doc.text(qname) + Doc.text(" unknown.") + Doc + .text(cmessage) + Doc.hardLine + + ctx, + Some(region) + ) case Infer.Error.SubsumptionCheckFailure(t0, t1, r0, r1, _) => val context0 = - if (r0 == r1) Doc.space // sometimes the region of the error is the same on right and left + if (r0 == r1) + Doc.space // sometimes the region of the error is the same on right and left else { - val m = lm.showRegion(r0, 2, errColor).getOrElse(Doc.str(r0)) // we should highlight the whole region + val m = lm + .showRegion(r0, 2, errColor) + .getOrElse(Doc.str(r0)) // we should highlight the whole region Doc.hardLine + m + Doc.hardLine } val context1 = - lm.showRegion(r1, 2, errColor).getOrElse(Doc.str(r1)) // we should highlight the whole region + lm.showRegion(r1, 2, errColor) + .getOrElse(Doc.str(r1)) // we should highlight the whole region val tmap = showTypes(pack, List(t0, t1)) val doc = Doc.text("type ") + tmap(t0) + context0 + @@ -246,8 +324,12 @@ object PackageError { context1 (doc, Some(r0)) - case uc@Infer.Error.UnknownConstructor((_, n), region, _) => - val near = nearest(n, uc.knownConstructors.map { case (_, n) => (n, ()) }.toMap, 3) + case uc @ Infer.Error.UnknownConstructor((_, n), region, _) => + val near = nearest( + n, + uc.knownConstructors.map { case (_, n) => (n, ()) }.toMap, + 3 + ) .map { case (n, _) => n.asString } val nearStr = @@ -255,7 +337,10 @@ object PackageError { else near.mkString(", nearest: ", ", ", "") val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + lm.showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region val doc = Doc.text("unknown constructor ") + Doc.text(n.asString) + Doc.text(nearStr) + Doc.hardLine + context @@ -263,9 +348,14 @@ object PackageError { case Infer.Error.KindCannotTyApply(applied, region) => val tmap = showTypes(pack, applied :: Nil) val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + lm.showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region val doc = Doc.text("kind error: for kind of the left of ") + - tmap(applied) + Doc.text(" is *. Cannot apply to kind *.") + Doc.hardLine + + tmap(applied) + Doc.text( + " is *. Cannot apply to kind *." + ) + Doc.hardLine + context (doc, Some(region)) @@ -274,99 +364,134 @@ object PackageError { val rightT = applied.arg val tmap = showTypes(pack, applied :: leftT :: rightT :: Nil) val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) - val doc = Doc.text("kind error: ") + Doc.text("the type: ") + tmap(applied) + - Doc.text(" is invalid because the left ") + tmap(leftT) + Doc.text(" has kind ") + Kind.toDoc(leftK) + - Doc.text(" and the right ") + tmap(rightT) + Doc.text(" has kind ") + Kind.toDoc(rightK) + - Doc.text(s" but left cannot accept the kind of the right:") + - Doc.hardLine + - context + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + val doc = + Doc.text("kind error: ") + Doc.text("the type: ") + tmap(applied) + + Doc.text(" is invalid because the left ") + tmap(leftT) + Doc + .text(" has kind ") + Kind.toDoc(leftK) + + Doc.text(" and the right ") + tmap(rightT) + Doc.text( + " has kind " + ) + Kind.toDoc(rightK) + + Doc.text(s" but left cannot accept the kind of the right:") + + Doc.hardLine + + context (doc, Some(region)) - case Infer.Error.KindMetaMismatch(meta, rightT, rightK, metaR, rightR) => + case Infer.Error.KindMetaMismatch( + meta, + rightT, + rightK, + metaR, + rightR + ) => val tmap = showTypes(pack, meta :: rightT :: Nil) val context0 = - lm.showRegion(metaR, 2, errColor).getOrElse(Doc.str(metaR)) // we should highlight the whole region + lm.showRegion(metaR, 2, errColor) + .getOrElse(Doc.str(metaR)) // we should highlight the whole region val context1 = { if (metaR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + // we should highlight the whole region - Doc.hardLine - } - else { + lm.showRegion(rightR, 2, errColor) + .getOrElse( + Doc.str(rightR) + ) + // we should highlight the whole region + Doc.hardLine + } else { Doc.empty } } - val doc = Doc.text("kind error: ") + Doc.text("the type: ") + tmap(meta) + - Doc.text(" of kind: ") + Kind.toDoc(meta.toMeta.kind) + Doc.text(" at: ") + Doc.hardLine + - context0 + Doc.hardLine + Doc.hardLine + - Doc.text("cannot be unified with the type ") + tmap(rightT) + - Doc.text(" of kind: ") + Kind.toDoc(rightK) + context1 + - Doc.hardLine + - Doc.text("because the first kind does not subsume the second.") + val doc = + Doc.text("kind error: ") + Doc.text("the type: ") + tmap(meta) + + Doc.text(" of kind: ") + Kind.toDoc(meta.toMeta.kind) + Doc.text( + " at: " + ) + Doc.hardLine + + context0 + Doc.hardLine + Doc.hardLine + + Doc.text("cannot be unified with the type ") + tmap(rightT) + + Doc.text(" of kind: ") + Kind.toDoc(rightK) + context1 + + Doc.hardLine + + Doc.text("because the first kind does not subsume the second.") (doc, Some(metaR)) case Infer.Error.UnexpectedMeta(meta, in, metaR, rightR) => val tymeta = Type.TyMeta(meta) val tmap = showTypes(pack, tymeta :: in :: Nil) val context0 = - lm.showRegion(metaR, 2, errColor).getOrElse(Doc.str(metaR)) // we should highlight the whole region + lm.showRegion(metaR, 2, errColor) + .getOrElse(Doc.str(metaR)) // we should highlight the whole region val context1 = { if (metaR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + // we should highlight the whole region - Doc.hardLine - } - else { + lm.showRegion(rightR, 2, errColor) + .getOrElse( + Doc.str(rightR) + ) + // we should highlight the whole region + Doc.hardLine + } else { Doc.empty } } val doc = Doc.text("Unexpected unknown: the type: ") + tmap(tymeta) + - Doc.text(" of kind: ") + Kind.toDoc(meta.kind) + Doc.text(" at: ") + Doc.hardLine + + Doc.text(" of kind: ") + Kind.toDoc(meta.kind) + Doc.text( + " at: " + ) + Doc.hardLine + context0 + Doc.hardLine + Doc.hardLine + Doc.text("inside the type ") + tmap(in) + context1 + Doc.hardLine + - Doc.text("this sometimes happens when a function arg has been omitted, or an illegal recursive type or function.") + Doc.text( + "this sometimes happens when a function arg has been omitted, or an illegal recursive type or function." + ) (doc, Some(metaR)) - case Infer.Error.KindNotUnifiable(leftK, leftT, rightK, rightT, leftR, rightR) => + case Infer.Error.KindNotUnifiable( + leftK, + leftT, + rightK, + rightT, + leftR, + rightR + ) => val tStr = showTypes(pack, leftT :: rightT :: Nil) val context0 = - lm.showRegion(leftR, 2, errColor).getOrElse(Doc.str(leftR)) + lm.showRegion(leftR, 2, errColor).getOrElse(Doc.str(leftR)) val context1 = { if (leftR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) - } - else { + lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + } else { Doc.empty } } val doc = Doc.text("kind mismatch error: ") + - tStr(leftT) + Doc.text(": ") + Kind.toDoc(leftK) + Doc.text(" at:") + Doc.hardLine + context0 + + tStr(leftT) + Doc.text(": ") + Kind.toDoc(leftK) + Doc.text( + " at:" + ) + Doc.hardLine + context0 + Doc.text(" cannot be unified with kind: ") + tStr(rightT) + Doc.text(": ") + Kind.toDoc(rightK) + context1 (doc, Some(leftR)) - case Infer.Error.NotPolymorphicEnough(tpe, _, _, region) => + case Infer.Error.NotPolymorphicEnough(tpe, _, _, region) => val tmap = showTypes(pack, tpe :: Nil) val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) - (Doc.text("the type ") + tmap(tpe) + Doc.text(" is not polymorphic enough") + Doc.hardLine + context, Some(region)) - case Infer.Error.ArityMismatch(leftA, leftR, rightA, rightR) => + ( + Doc.text("the type ") + tmap(tpe) + Doc.text( + " is not polymorphic enough" + ) + Doc.hardLine + context, + Some(region) + ) + case Infer.Error.ArityMismatch(leftA, leftR, rightA, rightR) => val context0 = - lm.showRegion(leftR, 2, errColor).getOrElse(Doc.str(leftR)) + lm.showRegion(leftR, 2, errColor).getOrElse(Doc.str(leftR)) val context1 = { if (leftR != rightR) { Doc.text(" at: ") + Doc.hardLine + - lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) - } - else { + lm.showRegion(rightR, 2, errColor).getOrElse(Doc.str(rightR)) + } else { Doc.empty } } @@ -374,57 +499,89 @@ object PackageError { def args(n: Int) = if (n == 1) "one argument" else s"$n arguments" - (Doc.text(s"function with ${args(leftA)} at:") + Doc.hardLine + context0 + - Doc.text(s" does not match function with ${args(rightA)}") + context1, Some(leftR)) + ( + Doc.text( + s"function with ${args(leftA)} at:" + ) + Doc.hardLine + context0 + + Doc.text( + s" does not match function with ${args(rightA)}" + ) + context1, + Some(leftR) + ) case Infer.Error.ArityTooLarge(found, max, region) => val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) - (Doc.text(s"function with $found arguments is too large. Maximum function argument count is $max.") + Doc.hardLine + context, - Some(region)) + ( + Doc.text( + s"function with $found arguments is too large. Maximum function argument count is $max." + ) + Doc.hardLine + context, + Some(region) + ) case Infer.Error.UnexpectedBound(bound, _, reg, _) => val tyvar = Type.TyVar(bound) val tmap = showTypes(pack, tyvar :: Nil) val context = - lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) + lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) - (Doc.text("unexpected bound: ") + tmap(tyvar) + Doc.hardLine + context, Some(reg)) + ( + Doc.text("unexpected bound: ") + tmap( + tyvar + ) + Doc.hardLine + context, + Some(reg) + ) case Infer.Error.UnionPatternBindMismatch(_, names, region) => val context = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) + lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) val uniqueSets = graph.Tree.distinctBy(names)(_.toSet) - val uniqs = Doc.intercalate(Doc.char(',') + Doc.line, + val uniqs = Doc.intercalate( + Doc.char(',') + Doc.line, uniqueSets.toList.map { names => - Doc.text(names.iterator.map(_.sourceCodeRepr).mkString("[", ", ", "]")) + Doc.text( + names.iterator.map(_.sourceCodeRepr).mkString("[", ", ", "]") + ) } ) - (Doc.text("not all union elements bind the same names: ") + - (Doc.line + uniqs + context).nested(4).grouped, - Some(region)) + ( + Doc.text("not all union elements bind the same names: ") + + (Doc.line + uniqs + context).nested(4).grouped, + Some(region) + ) case Infer.Error.UnknownDefined(const, reg) => val tpe = Type.TyConst(const) val tmap = showTypes(pack, tpe :: Nil) val context = - lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) + lm.showRegion(reg, 2, errColor).getOrElse(Doc.str(reg)) - (Doc.text("unknown type: ") + tmap(tpe) + Doc.hardLine + context, Some(reg)) + ( + Doc.text("unknown type: ") + tmap(tpe) + Doc.hardLine + context, + Some(reg) + ) case ie: Infer.Error.InternalError => (Doc.text(ie.message), Some(ie.region)) } - val h = sourceMap.headLine(pack, region) + val h = sourceMap.headLine(pack, region) (h + Doc.hardLine + teMessage).render(80) } } - case class SourceConverterErrorIn(err: SourceConverter.Error, pack: PackageName) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class SourceConverterErrorIn( + err: SourceConverter.Error, + pack: PackageName + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val msg = { val context = lm.showRegion(err.region, 2, errColor) - .getOrElse(Doc.str(err.region)) // we should highlight the whole region + .getOrElse( + Doc.str(err.region) + ) // we should highlight the whole region Doc.text(err.message) + Doc.hardLine + context } @@ -434,16 +591,27 @@ object PackageError { } } - case class TotalityCheckError(pack: PackageName, err: TotalityCheck.ExprError[Declaration]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class TotalityCheckError( + pack: PackageName, + err: TotalityCheck.ExprError[Declaration] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val region = err.matchExpr.tag.region val context1 = - lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + lm.showRegion(region, 2, errColor) + .getOrElse(Doc.str(region)) // we should highlight the whole region val teMessage = err match { case TotalityCheck.NonTotalMatch(_, missing) => - val allTypes = missing.traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) - .run._1.toList.distinct + val allTypes = missing + .traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) + .run + ._1 + .toList + .distinct val showT = showTypes(pack, allTypes) val doc = Pattern.compiledDocument(Document.instance[Type] { t => @@ -451,11 +619,17 @@ object PackageError { }) Doc.text("non-total match, missing: ") + - (Doc.intercalate(Doc.char(',') + Doc.lineOrSpace, - missing.toList.map(doc.document(_)))) + (Doc.intercalate( + Doc.char(',') + Doc.lineOrSpace, + missing.toList.map(doc.document(_)) + )) case TotalityCheck.UnreachableBranches(_, unreachableBranches) => - val allTypes = unreachableBranches.traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) - .run._1.toList.distinct + val allTypes = unreachableBranches + .traverse(_.traverseType { t => Writer(Chain.one(t), ()) }) + .run + ._1 + .toList + .distinct val showT = showTypes(pack, allTypes) val doc = Pattern.compiledDocument(Document.instance[Type] { t => @@ -463,13 +637,17 @@ object PackageError { }) Doc.text("unreachable branches: ") + - (Doc.intercalate(Doc.char(',') + Doc.lineOrSpace, - unreachableBranches.toList.map(doc.document(_)))) + (Doc.intercalate( + Doc.char(',') + Doc.lineOrSpace, + unreachableBranches.toList.map(doc.document(_)) + )) case TotalityCheck.InvalidPattern(_, err) => import TotalityCheck._ err match { case ArityMismatch((_, n), _, _, exp, found) => - Doc.text(s"arity mismatch: ${n.asString} expected $exp parameters, found $found") + Doc.text( + s"arity mismatch: ${n.asString} expected $exp parameters, found $found" + ) case UnknownConstructor((_, n), _, _) => Doc.text(s"unknown constructor: ${n.asString}") case InvalidStrPat(pat, _) => @@ -478,8 +656,10 @@ object PackageError { Doc.text(" (adjacent bindings aren't allowed)") case MultipleSplicesInPattern(_, _) => // TODO: get printing of compiled patterns working well - //val docp = Document[Pattern.Parsed].document(Pattern.ListPat(pat)) + - Doc.text("multiple splices in pattern, only one per match allowed") + // val docp = Document[Pattern.Parsed].document(Pattern.ListPat(pat)) + + Doc.text( + "multiple splices in pattern, only one per match allowed" + ) } } val prefix = sourceMap.headLine(pack, Some(region)) @@ -490,27 +670,43 @@ object PackageError { } } - case class UnusedLetError(pack: PackageName, errs: NonEmptyList[(Identifier.Bindable, Region)]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class UnusedLetError( + pack: PackageName, + errs: NonEmptyList[(Identifier.Bindable, Region)] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) val docs = errs .sortBy(_._2) .map { case (bn, region) => - val rdoc = lm.showRegion(region, 2, errColor).getOrElse(Doc.str(region)) // we should highlight the whole region + val rdoc = lm + .showRegion(region, 2, errColor) + .getOrElse(Doc.str(region)) // we should highlight the whole region val message = Doc.text("unused let binding: " + bn.sourceCodeRepr) message + Doc.hardLine + rdoc } val packDoc = sourceMap.headLine(pack, Some(errs.head._2)) val line2 = Doc.hardLine + Doc.hardLine - (packDoc + (line2 + Doc.intercalate(line2, docs.toList)).nested(2)).render(80) + (packDoc + (line2 + Doc.intercalate(line2, docs.toList)).nested(2)) + .render(80) } } - case class RecursionError(pack: PackageName, err: DefRecursionCheck.RecursionError) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class RecursionError( + pack: PackageName, + err: DefRecursionCheck.RecursionError + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) - val ctx = lm.showRegion(err.region, 2, errColor) + val ctx = lm + .showRegion(err.region, 2, errColor) .getOrElse(Doc.str(err.region)) // we should highlight the whole region val errMessage = err.message // TODO use the sourceMap/regions in RecursionError @@ -522,17 +718,22 @@ object PackageError { } } - case class DuplicatedPackageError(dups: NonEmptyMap[PackageName, (String, NonEmptyList[String])]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class DuplicatedPackageError( + dups: NonEmptyMap[PackageName, (String, NonEmptyList[String])] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val packDoc = Doc.text("package ") val dupInDoc = Doc.text(" duplicated in ") - val dupMessages = dups - .toSortedMap + val dupMessages = dups.toSortedMap .map { case (pname, (one, nelist)) => - val dupsrcs = Doc.intercalate(Doc.comma + Doc.lineOrSpace, - (one :: nelist.toList) - .sorted - .map(Doc.text(_)) + val dupsrcs = Doc + .intercalate( + Doc.comma + Doc.lineOrSpace, + (one :: nelist.toList).sorted + .map(Doc.text(_)) ) .nested(4) packDoc + Doc.text(pname.asString) + dupInDoc + dupsrcs @@ -542,46 +743,76 @@ object PackageError { } } - case class KindInferenceError(pack: PackageName, kindError: KindFormula.Error, regions: Map[Type.Const.Defined, Region]) extends PackageError { - def message(sourceMap: Map[PackageName, (LocationMap, String)], errColor: Colorize) = { + case class KindInferenceError( + pack: PackageName, + kindError: KindFormula.Error, + regions: Map[Type.Const.Defined, Region] + ) extends PackageError { + def message( + sourceMap: Map[PackageName, (LocationMap, String)], + errColor: Colorize + ) = { val (lm, _) = sourceMap.getMapSrc(pack) kindError match { - case KindFormula.Error.Unsatisfiable(_, _, _) => - val prefix = sourceMap.headLine(pack, None) + case KindFormula.Error.Unsatisfiable(_, _, _) => + val prefix = sourceMap.headLine(pack, None) (prefix + Doc.text(s": $kindError")).render(80) case KindFormula.Error.FromShapeError(se) => se match { case Shape.UnificationError(dt, cons, left, right) => val region = regions(dt.toTypeConst) val prefix = sourceMap.headLine(pack, Some(region)) - val ctx = lm.showRegion(region, 2, errColor) - .getOrElse(Doc.str(region)) // we should highlight the whole region - (prefix + Doc.hardLine + Doc.text("shape error: expected ") + Shape.shapeDoc(left) + Doc.text(" and ") + Shape.shapeDoc(right) + - Doc.text(s" to match in the constructor ${cons.name.sourceCodeRepr}") + Doc.hardLine + Doc.hardLine + + val ctx = lm + .showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region + (prefix + Doc.hardLine + Doc.text( + "shape error: expected " + ) + Shape.shapeDoc(left) + Doc.text(" and ") + Shape.shapeDoc( + right + ) + + Doc.text( + s" to match in the constructor ${cons.name.sourceCodeRepr}" + ) + Doc.hardLine + Doc.hardLine + ctx).render(80) case Shape.ShapeMismatch(dt, cons, outer, tyApp, right) => val tmap = showTypes(pack, outer :: tyApp :: Nil) val region = regions(dt.toTypeConst) val prefix = sourceMap.headLine(pack, Some(region)) - val ctx = lm.showRegion(region, 2, errColor) - .getOrElse(Doc.str(region)) // we should highlight the whole region + val ctx = lm + .showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region val typeDoc = - if (outer != tyApp) (tmap(outer) + Doc.text(" at application ") + tmap(tyApp)) + if (outer != tyApp) + (tmap(outer) + Doc.text(" at application ") + tmap(tyApp)) else tmap(outer) - (prefix + Doc.text(" shape error: expected ") + Shape.shapeDoc(right) + Doc.text(" -> ?") + Doc.text(" but found * ") + - Doc.text(s"in the constructor ${cons.name.sourceCodeRepr} inside type ") + - typeDoc + + (prefix + Doc.text(" shape error: expected ") + Shape.shapeDoc( + right + ) + Doc.text(" -> ?") + Doc.text(" but found * ") + + Doc.text( + s"in the constructor ${cons.name.sourceCodeRepr} inside type " + ) + + typeDoc + Doc.hardLine + Doc.hardLine + ctx).render(80) case Shape.FinishFailure(dt, left, right) => val region = regions(dt.toTypeConst) - val tdoc = showTypes(pack, dt.toTypeTyConst :: Nil)(dt.toTypeTyConst) + val tdoc = + showTypes(pack, dt.toTypeTyConst :: Nil)(dt.toTypeTyConst) val prefix = sourceMap.headLine(pack, Some(region)) - val message = Doc.text("in type ") + tdoc + Doc.text(" could not unify shapes: ") + Shape.shapeDoc(left) + Doc.text(" and ") + + val message = Doc.text("in type ") + tdoc + Doc.text( + " could not unify shapes: " + ) + Shape.shapeDoc(left) + Doc.text(" and ") + Shape.shapeDoc(right) - val ctx = lm.showRegion(region, 2, errColor) - .getOrElse(Doc.str(region)) // we should highlight the whole region + val ctx = lm + .showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region (prefix + Doc.hardLine + message + Doc.hardLine + ctx).render(80) case Shape.ShapeLoop(dt, tpe, _) => val region = regions(dt.toTypeConst) @@ -592,10 +823,16 @@ object PackageError { val tdocs = showTypes(pack, dt.toTypeTyConst :: tpe2 :: Nil) val prefix = sourceMap.headLine(pack, Some(region)) - val message = Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text(" cyclic dependency encountered in ") + - tdocs(tpe2) - val ctx = lm.showRegion(region, 2, errColor) - .getOrElse(Doc.str(region)) // we should highlight the whole region + val message = + Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text( + " cyclic dependency encountered in " + ) + + tdocs(tpe2) + val ctx = lm + .showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region (prefix + Doc.hardLine + message + Doc.hardLine + ctx).render(80) case Shape.UnboundVar(dt, cfn, v) => val region = regions(dt.toTypeConst) @@ -603,14 +840,19 @@ object PackageError { val tdocs = showTypes(pack, dt.toTypeTyConst :: tpe2 :: Nil) val prefix = sourceMap.headLine(pack, Some(region)) - val cfnMsg = if (dt.isStruct) Doc.empty else { - Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") - } + val cfnMsg = + if (dt.isStruct) Doc.empty + else { + Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") + } val message = Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text(" unbound type variable ") + tdocs(tpe2) + cfnMsg - val ctx = lm.showRegion(region, 2, errColor) - .getOrElse(Doc.str(region)) // we should highlight the whole region + val ctx = lm + .showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region (prefix + Doc.hardLine + message + Doc.hardLine + ctx).render(80) case Shape.UnknownConst(dt, cfn, c) => val region = regions(dt.toTypeConst) @@ -618,17 +860,22 @@ object PackageError { val tdocs = showTypes(pack, dt.toTypeTyConst :: tpe2 :: Nil) val prefix = sourceMap.headLine(pack, Some(region)) - val cfnMsg = if (dt.isStruct) Doc.empty else { - Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") - } + val cfnMsg = + if (dt.isStruct) Doc.empty + else { + Doc.text(s" in constructor ${cfn.name.sourceCodeRepr} ") + } val message = Doc.text("in type ") + tdocs(dt.toTypeTyConst) + Doc.text(" unknown type ") + tdocs(tpe2) + cfnMsg - val ctx = lm.showRegion(region, 2, errColor) - .getOrElse(Doc.str(region)) // we should highlight the whole region + val ctx = lm + .showRegion(region, 2, errColor) + .getOrElse( + Doc.str(region) + ) // we should highlight the whole region (prefix + Doc.hardLine + message + Doc.hardLine + ctx).render(80) } } } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala index 0b86221bc..be1e06afe 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageMap.scala @@ -2,7 +2,15 @@ package org.bykn.bosatsu import org.bykn.bosatsu.graph.Memoize import cats.{Foldable, Monad, Show} -import cats.data.{Ior, IorT, NonEmptyList, NonEmptyMap, Validated, ValidatedNel, ReaderT} +import cats.data.{ + Ior, + IorT, + NonEmptyList, + NonEmptyMap, + Validated, + ValidatedNel, + ReaderT +} import scala.collection.immutable.SortedMap import Identifier.Constructor @@ -12,45 +20,56 @@ import rankn.{DataRepr, TypeEnv} import cats.implicits._ -case class PackageMap[A, B, C, +D](toMap: SortedMap[PackageName, Package[A, B, C, D]]) { +case class PackageMap[A, B, C, +D]( + toMap: SortedMap[PackageName, Package[A, B, C, D]] +) { def +[D1 >: D](pack: Package[A, B, C, D1]): PackageMap[A, B, C, D1] = PackageMap(toMap + (pack.name -> pack)) - def ++[D1 >: D](packs: Iterable[Package[A, B, C, D1]]): PackageMap[A, B, C, D1] = + def ++[D1 >: D]( + packs: Iterable[Package[A, B, C, D1]] + ): PackageMap[A, B, C, D1] = packs.foldLeft(this: PackageMap[A, B, C, D1])(_ + _) - def getDataRepr(implicit ev: D <:< Program[TypeEnv[Any], Any, Any]): (PackageName, Constructor) => Option[DataRepr] = { - (pname, cons) => - toMap.get(pname) - .flatMap { pack => - ev(pack.program) - .types - .getConstructor(pname, cons) - .map(_._1.dataRepr(cons)) - } - } - - def allExternals(implicit ev: D <:< Program[TypeEnv[Any], Any, Any]): Map[PackageName, List[Identifier.Bindable]] = + def getDataRepr(implicit + ev: D <:< Program[TypeEnv[Any], Any, Any] + ): (PackageName, Constructor) => Option[DataRepr] = { (pname, cons) => toMap - .iterator - .map { case (name, pack) => - (name, ev(pack.program).externalDefs) + .get(pname) + .flatMap { pack => + ev(pack.program).types + .getConstructor(pname, cons) + .map(_._1.dataRepr(cons)) } - .toMap + } + + def allExternals(implicit + ev: D <:< Program[TypeEnv[Any], Any, Any] + ): Map[PackageName, List[Identifier.Bindable]] = + toMap.iterator.map { case (name, pack) => + (name, ev(pack.program).externalDefs) + }.toMap } object PackageMap { def empty[A, B, C, D]: PackageMap[A, B, C, D] = PackageMap(SortedMap.empty) - def fromIterable[A, B, C, D](ps: Iterable[Package[A, B, C, D]]): PackageMap[A, B, C, D] = + def fromIterable[A, B, C, D]( + ps: Iterable[Package[A, B, C, D]] + ): PackageMap[A, B, C, D] = empty[A, B, C, D] ++ ps import Package.FixPackage type MapF3[A, B, C] = PackageMap[FixPackage[A, B, C], A, B, C] type MapF2[A, B] = MapF3[A, A, B] - type ParsedImp = PackageMap[PackageName, Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] + type ParsedImp = PackageMap[ + PackageName, + Unit, + Unit, + (List[Statement], ImportMap[PackageName, Unit]) + ] type Resolved = MapF2[Unit, (List[Statement], ImportMap[PackageName, Unit])] type Typed[+T] = PackageMap[ Package.Interface, @@ -62,7 +81,7 @@ object PackageMap { Any ] ] - + type SourceMap = Map[PackageName, (LocationMap, String)] // convenience for type inference @@ -70,65 +89,97 @@ object PackageMap { type Inferred = Typed[Declaration] - /** - * This builds a DAG of actual packages where names have been replaced by the fully resolved - * packages - */ - def resolvePackages[A, B, C](map: PackageMap[PackageName, A, B, C], ifs: List[Package.Interface]): ValidatedNel[PackageError, MapF3[A, B, C]] = { + /** This builds a DAG of actual packages where names have been replaced by the + * fully resolved packages + */ + def resolvePackages[A, B, C]( + map: PackageMap[PackageName, A, B, C], + ifs: List[Package.Interface] + ): ValidatedNel[PackageError, MapF3[A, B, C]] = { val interfaceMap = ifs.iterator.map { iface => (iface.name, iface) }.toMap def getPackage( - i: Import[PackageName, A], - from: Package[PackageName, A, B, C]): ValidatedNel[PackageError, Import[Either[Package.Interface, Package[PackageName, A, B, C]], A]] = - map.toMap.get(i.pack) match { - case Some(pack) => Validated.valid(Import(Right(pack), i.items)) - case None => - interfaceMap.get(i.pack) match { - case Some(iface) => - Validated.valid(Import(Left(iface), i.items)) - case None => - Validated.invalidNel(PackageError.UnknownImportPackage(i.pack, from.name)) - } - } + i: Import[PackageName, A], + from: Package[PackageName, A, B, C] + ): ValidatedNel[PackageError, Import[ + Either[Package.Interface, Package[PackageName, A, B, C]], + A + ]] = + map.toMap.get(i.pack) match { + case Some(pack) => Validated.valid(Import(Right(pack), i.items)) + case None => + interfaceMap.get(i.pack) match { + case Some(iface) => + Validated.valid(Import(Left(iface), i.items)) + case None => + Validated.invalidNel( + PackageError.UnknownImportPackage(i.pack, from.name) + ) + } + } type PackageFix = Package[FixPackage[A, B, C], A, B, C] // We use the ReaderT to build the list of imports we are on // to detect circular dependencies, if the current package imports itself transitively we // want to report the full path - val step: Package[PackageName, A, B, C] => ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], PackageFix] = - Memoize.memoizeDagHashed[Package[PackageName, A, B, C], ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], PackageFix]] { (p, rec) => - val edeps = ReaderT.ask[Either[NonEmptyList[PackageError], *], List[PackageName]] - .flatMapF { - case nonE@(h :: tail) if nonE.contains(p.name) => - Left(NonEmptyList.of(PackageError.CircularDependency(p.name, NonEmptyList(h, tail)))) - case _ => - val deps = p.imports.traverse(getPackage(_, p)) // the packages p depends on - deps.toEither - } - - edeps - .flatMap { (deps: List[Import[Either[Package.Interface, Package[PackageName, A, B, C]], A]]) => - deps.traverse { i => - i.pack match { - case Right(pack) => - rec(pack) - .local[List[PackageName]](p.name :: _) // add this package into the path of all the deps - .map { p => Import(Package.fix[A, B, C](Right(p)), i.items) } - case Left(iface) => - ReaderT.pure[ - Either[NonEmptyList[PackageError], *], - List[PackageName], - Import[FixPackage[A, B, C], A]](Import(Package.fix[A, B, C](Left(iface)), i.items)) - } + val step: Package[PackageName, A, B, C] => ReaderT[Either[NonEmptyList[ + PackageError + ], *], List[PackageName], PackageFix] = + Memoize.memoizeDagHashed[Package[PackageName, A, B, C], ReaderT[ + Either[NonEmptyList[PackageError], *], + List[PackageName], + PackageFix + ]] { (p, rec) => + val edeps = ReaderT + .ask[Either[NonEmptyList[PackageError], *], List[PackageName]] + .flatMapF { + case nonE @ (h :: tail) if nonE.contains(p.name) => + Left( + NonEmptyList.of( + PackageError.CircularDependency(p.name, NonEmptyList(h, tail)) + ) + ) + case _ => + val deps = p.imports.traverse( + getPackage(_, p) + ) // the packages p depends on + deps.toEither } - .map { imports => - Package(p.name, imports, p.exports, p.program) + + edeps + .flatMap { + (deps: List[Import[ + Either[Package.Interface, Package[PackageName, A, B, C]], + A + ]]) => + deps + .traverse { i => + i.pack match { + case Right(pack) => + rec(pack) + .local[List[PackageName]]( + p.name :: _ + ) // add this package into the path of all the deps + .map { p => + Import(Package.fix[A, B, C](Right(p)), i.items) + } + case Left(iface) => + ReaderT.pure[Either[NonEmptyList[PackageError], *], List[ + PackageName + ], Import[FixPackage[A, B, C], A]]( + Import(Package.fix[A, B, C](Left(iface)), i.items) + ) + } + } + .map { imports => + Package(p.name, imports, p.exports, p.program) + } } - } - } + } type M = SortedMap[PackageName, PackageFix] - val r: ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], M] = + val r + : ReaderT[Either[NonEmptyList[PackageError], *], List[PackageName], M] = map.toMap.traverse(step) // we start with no imports on @@ -137,41 +188,69 @@ object PackageMap { m.map(PackageMap(_)).toValidated } - /** - * Convenience method to create a PackageMap then resolve it - */ - def resolveAll[A: Show](ps: List[(A, Package.Parsed)], ifs: List[Package.Interface]): Ior[NonEmptyList[PackageError], Resolved] = { + /** Convenience method to create a PackageMap then resolve it + */ + def resolveAll[A: Show]( + ps: List[(A, Package.Parsed)], + ifs: List[Package.Interface] + ): Ior[NonEmptyList[PackageError], Resolved] = { type AP = (A, Package.Parsed) - val (nonUnique, unique): (SortedMap[PackageName, (AP, NonEmptyList[AP])], SortedMap[PackageName, AP]) = + val (nonUnique, unique): ( + SortedMap[PackageName, (AP, NonEmptyList[AP])], + SortedMap[PackageName, AP] + ) = NonEmptyList.fromList(ps) match { case Some(neps) => - CollectionUtils.uniqueByKey(neps)(_._2.name) + CollectionUtils + .uniqueByKey(neps)(_._2.name) .fold( { a => (a.toSortedMap, SortedMap.empty[PackageName, AP]) }, - { b => (SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], b.toSortedMap) }, + { b => + ( + SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], + b.toSortedMap + ) + }, { (a, b) => (a.toSortedMap, b.toSortedMap) } ) case None => - (SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], SortedMap.empty[PackageName, AP]) + ( + SortedMap.empty[PackageName, (AP, NonEmptyList[AP])], + SortedMap.empty[PackageName, AP] + ) } - def toProg(p: Package.Parsed): - (Option[PackageError], - Package[PackageName, Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])]) = { + def toProg(p: Package.Parsed): ( + Option[PackageError], + Package[ + PackageName, + Unit, + Unit, + (List[Statement], ImportMap[PackageName, Unit]) + ] + ) = { val (errs0, imap) = ImportMap.fromImports(p.imports) val errs = - NonEmptyList.fromList(errs0) + NonEmptyList + .fromList(errs0) .map(PackageError.DuplicatedImport) (errs, p.mapProgram((_, imap))) } // we know all the package names are unique here - def foldMap(m: Map[PackageName, (A, Package.Parsed)]): (List[PackageError], PackageMap.ParsedImp) = { + def foldMap( + m: Map[PackageName, (A, Package.Parsed)] + ): (List[PackageError], PackageMap.ParsedImp) = { val initPm = PackageMap - .empty[PackageName, Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])] + .empty[ + PackageName, + Unit, + Unit, + (List[Statement], ImportMap[PackageName, Unit]) + ] m.iterator.foldLeft((List.empty[PackageError], initPm)) { case ((errs, pm), (_, (_, pack))) => @@ -195,9 +274,13 @@ object PackageMap { NonEmptyMap.fromMap(nonUnique) match { case Some(nenu) => val paths = nenu.map { case ((a, _), rest) => - (a.show, rest.map(_._1.show)) + (a.show, rest.map(_._1.show)) } - Ior.left(NonEmptyList.one[PackageError](PackageError.DuplicatedPackageError(paths))) + Ior.left( + NonEmptyList.one[PackageError]( + PackageError.DuplicatedPackageError(paths) + ) + ) case None => Ior.right(()) } @@ -205,10 +288,11 @@ object PackageMap { (nuEr, check, res.toIor).parMapN { (_, _, r) => r } } - /** - * Infer all the types in a resolved PackageMap - */ - def inferAll(ps: Resolved)(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = { + /** Infer all the types in a resolved PackageMap + */ + def inferAll( + ps: Resolved + )(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = { import Par.F @@ -217,48 +301,65 @@ object PackageMap { FixPackage[Unit, Unit, (List[Statement], ImportMap[PackageName, Unit])], Unit, Unit, - (List[Statement], ImportMap[PackageName, Unit])] + (List[Statement], ImportMap[PackageName, Unit]) + ] type FutVal[A] = IorT[F, NonEmptyList[PackageError], A] /* * We memoize this function to avoid recomputing diamond dependencies */ - val infer0: ResolvedU => Par.F[Ior[NonEmptyList[PackageError], (TypeEnv[Kind.Arg], Package.Inferred)]] = - Memoize.memoizeDagFuture[ResolvedU, Ior[NonEmptyList[PackageError], (TypeEnv[Kind.Arg], Package.Inferred)]] { + val infer0: ResolvedU => Par.F[ + Ior[NonEmptyList[PackageError], (TypeEnv[Kind.Arg], Package.Inferred)] + ] = + Memoize.memoizeDagFuture[ResolvedU, Ior[NonEmptyList[ + PackageError + ], (TypeEnv[Kind.Arg], Package.Inferred)]] { // TODO, we ignore importMap here, we only check earlier we don't // have duplicate imports case (Package(nm, imports, exports, (stmt, _)), recurse) => - - def getImport[A, B](packF: Package.Inferred, - exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], - i: ImportedName[B]): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = + def getImport[A, B]( + packF: Package.Inferred, + exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], + i: ImportedName[B] + ): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = exMap.get(i.originalName) match { case None => - Ior.left(NonEmptyList.one( - PackageError.UnknownImportName( - nm, - packF.name, - packF - .program - .lets - .iterator - .map { case (n, _, _) => (n: Identifier, ()) }.toMap, - i, - exMap.iterator.flatMap(_._2.toList).toList))) + Ior.left( + NonEmptyList.one( + PackageError.UnknownImportName( + nm, + packF.name, + packF.program.lets.iterator.map { case (n, _, _) => + (n: Identifier, ()) + }.toMap, + i, + exMap.iterator.flatMap(_._2.toList).toList + ) + ) + ) case Some(exps) => val bs = exps.map(_.tag) Ior.right(i.map(_ => bs)) } - def getImportIface[A, B](packF: Package.Interface, - exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], - i: ImportedName[B]): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = + def getImportIface[A, B]( + packF: Package.Interface, + exMap: Map[Identifier, NonEmptyList[ExportedName[A]]], + i: ImportedName[B] + ): Ior[NonEmptyList[PackageError], ImportedName[NonEmptyList[A]]] = exMap.get(i.originalName) match { case None => - Ior.left(NonEmptyList.one( - PackageError.UnknownImportFromInterface( - nm, packF.name, packF.exports.map(_.name), i, - exMap.iterator.flatMap(_._2.toList).toList))) + Ior.left( + NonEmptyList.one( + PackageError.UnknownImportFromInterface( + nm, + packF.name, + packF.exports.map(_.name), + i, + exMap.iterator.flatMap(_._2.toList).toList + ) + ) + ) case Some(exps) => val bs = exps.map(_.tag) Ior.right(i.map(_ => bs)) @@ -271,8 +372,11 @@ object PackageMap { * type can have the same name as a constructor. After this step, each * distinct object has its own entry in the list */ - type ImpRes = Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] - def stepImport(imp: Import[Package.Resolved, Unit]): FutVal[ImpRes] = { + type ImpRes = + Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] + def stepImport( + imp: Import[Package.Resolved, Unit] + ): FutVal[ImpRes] = { val Import(fixpack, items) = imp Package.unfix(fixpack) match { case Right(p) => @@ -308,45 +412,57 @@ object PackageMap { inferImports .flatMap { imps => // run this in a thread - IorT(Par.start(Package.inferBodyUnopt(nm, imps, stmt).map((imps, _)))) + IorT( + Par.start( + Package.inferBodyUnopt(nm, imps, stmt).map((imps, _)) + ) + ) } - inferBody - .flatMap { case (imps, (fte, program@Program(types, lets, _, _))) => + inferBody.flatMap { + case (imps, (fte, program @ Program(types, lets, _, _))) => val ior = ExportedName .buildExports(nm, exports, types, lets) match { - case Validated.Valid(exports) => - // We have a result, which we can continue to check - val res = (fte, Package(nm, imps, exports, program)) - NonEmptyList.fromList(Package.checkValuesHaveExportedTypes(nm, exports)) match { - case None => Ior.right(res) - case Some(errs) => Ior.both(errs, res) - } - case Validated.Invalid(badPackages) => - Ior.left(badPackages.map { n => - PackageError.UnknownExport(n, nm, lets): PackageError - }) - } + case Validated.Valid(exports) => + // We have a result, which we can continue to check + val res = (fte, Package(nm, imps, exports, program)) + NonEmptyList.fromList( + Package.checkValuesHaveExportedTypes(nm, exports) + ) match { + case None => Ior.right(res) + case Some(errs) => Ior.both(errs, res) + } + case Validated.Invalid(badPackages) => + Ior.left(badPackages.map { n => + PackageError.UnknownExport(n, nm, lets): PackageError + }) + } IorT.fromIor(ior) - } - .value - } + }.value + } /* * Since Par.F is starts computation when start is called * we want to start all the computations *then* collect * the result together */ - val infer: ResolvedU => Par.F[Ior[NonEmptyList[PackageError], Package.Inferred]] = + val infer: ResolvedU => Par.F[ + Ior[NonEmptyList[PackageError], Package.Inferred] + ] = infer0.andThen { parF => // As soon as each Par.F is complete, we can start normalizing that one Monad[Par.F].flatMap(parF) { ior => - ior.traverse { - case (fte, pack) => - Par.start { - val optPack = pack.copy(program = TypedExprNormalization.normalizeProgram(pack.name, fte, pack.program)) - Package.discardUnused(optPack) - } + ior.traverse { case (fte, pack) => + Par.start { + val optPack = pack.copy(program = + TypedExprNormalization.normalizeProgram( + pack.name, + fte, + pack.program + ) + ) + Package.discardUnused(optPack) + } } } } @@ -354,54 +470,74 @@ object PackageMap { val fut = ps.toMap.parTraverse(infer.andThen(IorT(_))) // Wait until all the resolution is complete - Par.await(fut.value) + Par + .await(fut.value) .map(PackageMap(_)) } def resolveThenInfer[A: Show]( - ps: List[(A, Package.Parsed)], - ifs: List[Package.Interface])(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = - resolveAll(ps, ifs).flatMap(inferAll) - - def buildSourceMap[F[_]: Foldable, A](parsedFiles: F[((A, LocationMap), Package.Parsed)]): Map[PackageName, (LocationMap, String)] = - parsedFiles.foldLeft(Map.empty[PackageName, (LocationMap, String)]) { case (map, ((path, lm), pack)) => - map.updated(pack.name, (lm, path.toString)) + ps: List[(A, Package.Parsed)], + ifs: List[Package.Interface] + )(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], Inferred] = + resolveAll(ps, ifs).flatMap(inferAll) + + def buildSourceMap[F[_]: Foldable, A]( + parsedFiles: F[((A, LocationMap), Package.Parsed)] + ): Map[PackageName, (LocationMap, String)] = + parsedFiles.foldLeft(Map.empty[PackageName, (LocationMap, String)]) { + case (map, ((path, lm), pack)) => + map.updated(pack.name, (lm, path.toString)) } /** typecheck a list of packages given a list of interface dependencies - * - * @param packs a list of parsed packages, along with a key A to tag the source - * @param ifs the interfaces we are compiling against. If Bosatsu.Predef is not in this list, the default is added - */ + * + * @param packs + * a list of parsed packages, along with a key A to tag the source + * @param ifs + * the interfaces we are compiling against. If Bosatsu.Predef is not in + * this list, the default is added + */ def typeCheckParsed[A: Show]( - packs: NonEmptyList[((A, LocationMap), Package.Parsed)], - ifs: List[Package.Interface], - predefKey: A)(implicit cpuEC: Par.EC): Ior[NonEmptyList[PackageError], PackageMap.Inferred] = { + packs: NonEmptyList[((A, LocationMap), Package.Parsed)], + ifs: List[Package.Interface], + predefKey: A + )(implicit + cpuEC: Par.EC + ): Ior[NonEmptyList[PackageError], PackageMap.Inferred] = { // if we have passed in a use supplied predef, don't use the internal one - val useInternalPredef = !ifs.exists { (p: Package.Interface) => p.name == PackageName.PredefName } + val useInternalPredef = !ifs.exists { (p: Package.Interface) => + p.name == PackageName.PredefName + } // Now we have completed all IO, here we do all the checks we need for correctness val parsed = - if (useInternalPredef) withPredefA[(A, LocationMap)]((predefKey, LocationMap("")), packs.toList) - else withPredefImportsA[(A, LocationMap)](packs.toList) - - PackageMap.resolveThenInfer[A]( - parsed.map { case ((a, _), p) => (a, p) }, - ifs) + if (useInternalPredef) + withPredefA[(A, LocationMap)]( + (predefKey, LocationMap("")), + packs.toList + ) + else withPredefImportsA[(A, LocationMap)](packs.toList) + + PackageMap + .resolveThenInfer[A](parsed.map { case ((a, _), p) => (a, p) }, ifs) } - /** - * Here is the fully compiled Predef - */ + /** Here is the fully compiled Predef + */ val predefCompiled: Package.Inferred = { import DirectEC.directEC - //implicit val showUnit: Show[Unit] = Show.show[Unit](_ => "predefCompiled") - val inferred = PackageMap.resolveThenInfer(((), Package.predefPackage) :: Nil, Nil).strictToValidated + // implicit val showUnit: Show[Unit] = Show.show[Unit](_ => "predefCompiled") + val inferred = PackageMap + .resolveThenInfer(((), Package.predefPackage) :: Nil, Nil) + .strictToValidated inferred match { case Validated.Valid(v) => v.toMap.get(PackageName.PredefName) match { - case None => sys.error("internal error: predef package not found after compilation") + case None => + sys.error( + "internal error: predef package not found after compilation" + ) case Some(inf) => inf } case Validated.Invalid(errs) => @@ -418,13 +554,18 @@ object PackageMap { private val predefImports: Import[PackageName, Unit] = Import(PackageName.PredefName, NonEmptyList.fromList(predefImportList).get) - private def withPredefImportsA[A](ps: List[(A, Package.Parsed)]): List[(A, Package.Parsed)] = + private def withPredefImportsA[A]( + ps: List[(A, Package.Parsed)] + ): List[(A, Package.Parsed)] = ps.map { case (a, p) => (a, p.withImport(predefImports)) } def withPredef(ps: List[Package.Parsed]): List[Package.Parsed] = Package.predefPackage :: ps.map(_.withImport(predefImports)) - def withPredefA[A](predefA: A, ps: List[(A, Package.Parsed)]): List[(A, Package.Parsed)] = + def withPredefA[A]( + predefA: A, + ps: List[(A, Package.Parsed)] + ): List[(A, Package.Parsed)] = (predefA, Package.predefPackage) :: withPredefImportsA(ps) } diff --git a/core/src/main/scala/org/bykn/bosatsu/PackageName.scala b/core/src/main/scala/org/bykn/bosatsu/PackageName.scala index 360ec8e7b..b8c77ab19 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PackageName.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PackageName.scala @@ -14,7 +14,7 @@ case class PackageName(parts: NonEmptyList[String]) { object PackageName { def parts(first: String, rest: String*): PackageName = - PackageName(NonEmptyList.of(first, rest :_*)) + PackageName(NonEmptyList.of(first, rest: _*)) implicit val document: Document[PackageName] = Document.instance[PackageName] { pn => Doc.text(pn.asString) } @@ -28,7 +28,7 @@ object PackageName { def parse(s: String): Option[PackageName] = parser.parse(s) match { case Right(("", pn)) => Some(pn) - case _ => None + case _ => None } implicit val order: Order[PackageName] = @@ -40,4 +40,3 @@ object PackageName { val PredefName: PackageName = PackageName(NonEmptyList.of("Bosatsu", "Predef")) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Padding.scala b/core/src/main/scala/org/bykn/bosatsu/Padding.scala index e37962c9d..f9286d121 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Padding.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Padding.scala @@ -2,7 +2,7 @@ package org.bykn.bosatsu import cats.Functor import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import Parser.maybeSpace @@ -20,9 +20,8 @@ object Padding { Doc.line.repeat(padding.lines) + Document[T].document(padding.padded) } - /** - * This allows an empty padding - */ + /** This allows an empty padding + */ def parser[T](p: P[T]): P[Padding[T]] = { val spacing = (maybeSpace.with1.soft ~ Parser.newline).void.rep0 @@ -30,17 +29,14 @@ object Padding { .map { case (vec, t) => Padding(vec.size, t) } } - /** - * Parses a padding of length 1 or more, then p - */ + /** Parses a padding of length 1 or more, then p + */ def parser1[T](p: P0[T]): P[Padding[T]] = ((maybeSpace.with1.soft ~ Parser.newline).void.rep ~ p) .map { case (vec, t) => Padding(vec.size, t) } - /** - * This is parser1 by itself, with the padded value being () - */ + /** This is parser1 by itself, with the padded value being () + */ val nonEmptyParser: P[Padding[Unit]] = parser1(P.unit) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Parser.scala b/core/src/main/scala/org/bykn/bosatsu/Parser.scala index ff30773a5..27168859a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Parser.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Parser.scala @@ -8,12 +8,10 @@ import scala.collection.immutable.SortedMap import cats.implicits._ object Parser { - /** - * This is an indentation aware - * parser, the input is the string that - * should be parsed after a new-line to - * continue the current indentation block - */ + + /** This is an indentation aware parser, the input is the string that should + * be parsed after a new-line to continue the current indentation block + */ type Indy[A] = Kleisli[P, String, A] object Indy { @@ -23,9 +21,8 @@ object Parser { def lift[A](p: P[A]): Indy[A] = Kleisli.liftF(p) - /** - * Parse spaces, end of line, then the next indentation - */ + /** Parse spaces, end of line, then the next indentation + */ val toEOLIndent: Indy[Unit] = apply { indent => toEOL1 *> P.string0(indent) @@ -35,10 +32,8 @@ object Parser { def region: Indy[(Region, A)] = toKleisli.mapF(_.region) - /** - * Parse exactly the current indentation - * starting now - */ + /** Parse exactly the current indentation starting now + */ def indentBefore: Indy[A] = apply(indent => P.string0(indent).with1 *> toKleisli.run(indent)) @@ -64,16 +59,13 @@ object Parser { toKleisli(indent) *> that(indent) } - /** - * This optionally allows extra indentation that starts now - */ + /** This optionally allows extra indentation that starts now + */ def maybeMore: Parser.Indy[A] = Indy { indent => // run this one time, not each spaces are parsed val noIndent = toKleisli.run(indent) - val someIndent: P[A] = Parser - .spaces - .string + val someIndent: P[A] = Parser.spaces.string .flatMap { thisIndent => toKleisli.run(indent + thisIndent) } @@ -92,25 +84,35 @@ object Parser { } object Error { - case class ParseFailure(position: Int, locations: LocationMap, expected: NonEmptyList[P.Expectation]) extends Error - - def showExpectations(locations: LocationMap, expected: NonEmptyList[P.Expectation], errColor: LocationMap.Colorize): Doc = { - val errs: SortedMap[Int, NonEmptyList[P.Expectation]] = expected.groupBy(_.offset) + case class ParseFailure( + position: Int, + locations: LocationMap, + expected: NonEmptyList[P.Expectation] + ) extends Error + + def showExpectations( + locations: LocationMap, + expected: NonEmptyList[P.Expectation], + errColor: LocationMap.Colorize + ): Doc = { + val errs: SortedMap[Int, NonEmptyList[P.Expectation]] = + expected.groupBy(_.offset) def show(s: String): Doc = { val q = '\'' if (s.forall(_.isWhitespace)) { val chars = s.length val plural = if (chars == 1) "char" else "chars" - Doc.text(s"$chars whitespace $plural \"") + Doc.intercalate(Doc.empty, + Doc.text(s"$chars whitespace $plural \"") + Doc.intercalate( + Doc.empty, s.map { case '\t' => Doc.text("\\t") case '\n' => Doc.text("\\n") case '\r' => Doc.text("\\r") - case c => Doc.char(c) - }) + Doc.char('"') - } - else { + case c => Doc.char(c) + } + ) + Doc.char('"') + } else { Doc.char(q) + Doc.text(escape(q, s)) + Doc.char(q) } } @@ -122,22 +124,30 @@ object Parser { case one :: Nil => Doc.text("expected ") + show(one) case _ => - Doc.text("expected one of: ") + Doc.intercalate(Doc.line, strs.map(show)).grouped.nested(4) + Doc.text("expected one of: ") + Doc + .intercalate(Doc.line, strs.map(show)) + .grouped + .nested(4) } case P.Expectation.InRange(_, lower, upper) => if (lower == upper) { Doc.text("expected char: ") + show(lower.toString) + } else { + Doc.text("expected char in range: [") + show(lower.toString) + Doc + .text(", ") + show(upper.toString) + Doc.text("]") } - else { - Doc.text("expected char in range: [") + show(lower.toString) + Doc.text(", ") + show(upper.toString) + Doc.text("]") - } - case P.Expectation.StartOfString(_) => Doc.text("expected start of the file") + case P.Expectation.StartOfString(_) => + Doc.text("expected start of the file") case P.Expectation.EndOfString(_, length) => Doc.text(s"expected end of file but $length characters remaining") case P.Expectation.Length(_, expected, actual) => - Doc.text(s"expected $expected more characters but only $actual remaining") + Doc.text( + s"expected $expected more characters but only $actual remaining" + ) case P.Expectation.ExpectedFailureAt(_, matched) => - Doc.text("expected failure but the parser matched: ") + show(matched) + Doc.text("expected failure but the parser matched: ") + show( + matched + ) case P.Expectation.Fail(_) => Doc.text("failed") case P.Expectation.FailWith(_, message) => @@ -146,10 +156,14 @@ object Parser { expToDoc(expect) } - Doc.intercalate(Doc.hardLine, errs.map { case (pos, xs) => - locations.showContext(pos, 2, errColor).get + ( - Doc.hardLine + Doc.intercalate(Doc.comma + Doc.line, xs.toList.map(expToDoc)).grouped).nested(4) - }) + Doc.intercalate( + Doc.hardLine, + errs.map { case (pos, xs) => + locations.showContext(pos, 2, errColor).get + (Doc.hardLine + Doc + .intercalate(Doc.comma + Doc.line, xs.toList.map(expToDoc)) + .grouped).nested(4) + } + ) } } @@ -165,14 +179,16 @@ object Parser { } val identifierCharsP: P0[String] = - P.charIn('_' :: ('a' to 'z').toList ::: ('A' to 'Z').toList ::: ('0' to '9').toList).repAs0 + P.charIn( + '_' :: ('a' to 'z').toList ::: ('A' to 'Z').toList ::: ('0' to '9').toList + ).repAs0 // parse one or more space characters val spaces: P[Unit] = P.charIn(Set(' ', '\t')).rep.void val maybeSpace: P0[Unit] = spaces.?.void /** prefer to parse Right, then Left - */ + */ def either[A, B](pb: P0[B], pa: P0[A]): P0[Either[B, A]] = pa.map(Right(_)).orElse(pb.map(Left(_))) @@ -192,7 +208,9 @@ object Parser { (P.charIn('A' to 'Z') ~ identifierCharsP).string val py2Ident: P[String] = - (P.charIn('_' :: ('A' to 'Z').toList ::: ('a' to 'z').toList) ~ identifierCharsP).string + (P.charIn( + '_' :: ('A' to 'Z').toList ::: ('a' to 'z').toList + ) ~ identifierCharsP).string // parse a keyword and some space or backtrack def keySpace(str: String): P[Unit] = @@ -200,16 +218,13 @@ object Parser { val digit19: P[Char] = P.charIn('1' to '9') val digit09: P[Char] = P.charIn('0' to '9') - /** - * This parser allows _ between any two digits to allow - * literals such as: - * 1_000_000 - * - * It will also parse terrible examples like: - * 1_0_0_0_0_0_0 - * but I think banning things like that shouldn't - * be done by the parser - */ + + /** This parser allows _ between any two digits to allow literals such as: + * 1_000_000 + * + * It will also parse terrible examples like: 1_0_0_0_0_0_0 but I think + * banning things like that shouldn't be done by the parser + */ val integerString: P[String] = { val rest = (P.char('_').?.with1 ~ digit09).rep0 @@ -220,19 +235,13 @@ object Parser { } object JsonNumber { - /** - * from: https://tools.ietf.org/html/rfc4627 - * number = [ minus ] int [ frac ] [ exp ] - * decimal-point = %x2E ; . - * digit1-9 = %x31-39 ; 1-9 - * e = %x65 / %x45 ; e E - * exp = e [ minus / plus ] 1*DIGIT - * frac = decimal-point 1*DIGIT - * int = zero / ( digit1-9 *DIGIT ) - * minus = %x2D ; - - * plus = %x2B ; + - * zero = %x30 ; 0 - */ + + /** from: https://tools.ietf.org/html/rfc4627 number = [ minus ] int [ frac + * ] [ exp ] decimal-point = %x2E ; . digit1-9 = %x31-39 ; 1-9 e = %x65 / + * %x45 ; e E exp = e [ minus / plus ] 1*DIGIT frac = decimal-point 1*DIGIT + * int = zero / ( digit1-9 *DIGIT ) minus = %x2D ; - plus = %x2B ; + zero = + * %x30 ; 0 + */ val digits: P0[Unit] = digit09.rep0.void val digits1: P[Unit] = digit09.rep.void val int: P[Unit] = P.char('0') <+> (digit19 ~ digits).void @@ -243,7 +252,12 @@ object Parser { (P.char('-').?.with1 ~ int ~ frac.? ~ exp.?).string // this gives you the individual parts of a floating point string - case class Parts(negative: Boolean, leftOfPoint: String, floatingPart: String, exp: String) { + case class Parts( + negative: Boolean, + leftOfPoint: String, + floatingPart: String, + exp: String + ) { def asString: String = { val neg = if (negative) "-" else "" s"$neg$leftOfPoint$floatingPart$exp" @@ -268,14 +282,13 @@ object Parser { def nonEmptyListToList[T](p: P0[NonEmptyList[T]]): P0[List[T]] = p.?.map { - case None => Nil + case None => Nil case Some(ne) => ne.toList } - /** - * Parse python-like dicts: delimited by curlies "{" "}" and - * keys separated by colon - */ + /** Parse python-like dicts: delimited by curlies "{" "}" and keys separated + * by colon + */ def dictLikeParser[K, V](pkey: P[K], pvalue: P[V]): P[List[(K, V)]] = { val ws = maybeSpacesAndLines val kv = (pkey ~ ((ws ~ P.char(':') ~ ws).with1 *> pvalue)) @@ -295,11 +308,15 @@ object Parser { def maybeAp(fn: P0[T => T]): P[T] = (item ~ fn.?) .map { - case (a, None) => a + case (a, None) => a case (a, Some(f)) => f(a) } - def nonEmptyListOfWsSep(ws: P0[Unit], sep: P0[Unit], allowTrailing: Boolean): P[NonEmptyList[T]] = { + def nonEmptyListOfWsSep( + ws: P0[Unit], + sep: P0[Unit], + allowTrailing: Boolean + ): P[NonEmptyList[T]] = { val wsSep = (ws.soft ~ sep ~ ws).void val trail = if (allowTrailing) (ws.soft ~ sep).?.void @@ -332,33 +349,29 @@ object Parser { parens(item) def parensLines1Cut: P[NonEmptyList[T]] = - item.nonEmptyListOfWs(maybeSpacesAndLines) - .parensCut + item.nonEmptyListOfWs(maybeSpacesAndLines).parensCut def parensLines0Cut: P[List[T]] = parens(nonEmptyListToList(item.nonEmptyListOfWs(maybeSpacesAndLines))) - /** - * either: a, b, c, .. - * or (a, b, c, ) where we allow newlines: - * return true if we do have parens - */ + + /** either: a, b, c, .. or (a, b, c, ) where we allow newlines: return true + * if we do have parens + */ def itemsMaybeParens: P[(Boolean, NonEmptyList[T])] = { val withP = item.parensLines1Cut.map((true, _)) val noP = item.nonEmptyListOfWs(maybeSpace).map((false, _)) withP.orElse(noP) } - /** - * Parse a python-like tuple or a parens - */ + /** Parse a python-like tuple or a parens + */ def tupleOrParens: P[Either[T, List[T]]] = parens { - tupleOrParens0.? - .map { - case None => Right(Nil) - case Some(Left(t)) => Left(t) - case Some(Right(l)) => Right(l.toList) - } + tupleOrParens0.?.map { + case None => Right(Nil) + case Some(Left(t)) => Left(t) + case Some(Right(l)) => Right(l.toList) + } } def tupleOrParens0: P[Either[T, NonEmptyList[T]]] = { @@ -398,7 +411,8 @@ object Parser { case Right(a) => a case Left(err) => val idx = err.failedAtOffset - sys.error(s"failed to parse: $str: at $idx: (${str.substring(idx)}) with errors: ${err.expected}") + sys.error(s"failed to parse: $str: at $idx: (${str + .substring(idx)}) with errors: ${err.expected}") } sealed abstract class MaybeTupleOrParens[A] @@ -410,7 +424,7 @@ object Parser { def tupleOrParens[A](p: P[A]): P[NotBare[A]] = p.tupleOrParens.map { - case Right(tup) => Tuple(tup) + case Right(tup) => Tuple(tup) case Left(parens) => Parens(parens) } diff --git a/core/src/main/scala/org/bykn/bosatsu/PathGen.scala b/core/src/main/scala/org/bykn/bosatsu/PathGen.scala index d888c959c..7fd0c3127 100644 --- a/core/src/main/scala/org/bykn/bosatsu/PathGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/PathGen.scala @@ -12,14 +12,20 @@ object PathGen { def read(implicit m: Monad[IO]): IO[List[Path]] = m.pure(path :: Nil) } - final case class ChildrenOfDir[IO[_], Path](dir: Path, select: Path => Boolean, recurse: Boolean, unfold: Path => IO[Option[IO[List[Path]]]]) extends PathGen[IO, Path] { + final case class ChildrenOfDir[IO[_], Path]( + dir: Path, + select: Path => Boolean, + recurse: Boolean, + unfold: Path => IO[Option[IO[List[Path]]]] + ) extends PathGen[IO, Path] { def read(implicit m: Monad[IO]): IO[List[Path]] = { val pureEmpty: IO[List[Path]] = m.pure(Nil) lazy val rec: List[Path] => IO[List[Path]] = - if (recurse) { (children: List[Path]) => children.traverse(step).map(_.flatten) } - else { (_: List[Path]) => pureEmpty } + if (recurse) { (children: List[Path]) => + children.traverse(step).map(_.flatten) + } else { (_: List[Path]) => pureEmpty } def step(path: Path): IO[List[Path]] = unfold(path).flatMap { @@ -35,7 +41,8 @@ object PathGen { step(dir) } } - final case class Combine[IO[_], Path](gens: List[PathGen[IO, Path]]) extends PathGen[IO, Path] { + final case class Combine[IO[_], Path](gens: List[PathGen[IO, Path]]) + extends PathGen[IO, Path] { def read(implicit m: Monad[IO]): IO[List[Path]] = gens.traverse(_.read).map(_.flatten) } @@ -45,12 +52,12 @@ object PathGen { val empty: PathGen[IO, Path] = Combine(Nil) def combine(a: PathGen[IO, Path], b: PathGen[IO, Path]) = (a, b) match { - case (Combine(Nil), b) => b - case (a, Combine(Nil)) => a + case (Combine(Nil), b) => b + case (a, Combine(Nil)) => a case (Combine(as), Combine(bs)) => Combine(as ::: bs) - case (Combine(as), b) => Combine(as :+ b) - case (a, Combine(bs)) => Combine(a :: bs) - case (a, b) => Combine(a :: b :: Nil) + case (Combine(as), b) => Combine(as :+ b) + case (a, Combine(bs)) => Combine(a :: bs) + case (a, b) => Combine(a :: b :: Nil) } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index e79c11175..7cac7b77d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -3,10 +3,10 @@ package org.bykn.bosatsu import cats.{Applicative, Foldable} import cats.data.NonEmptyList import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import org.bykn.bosatsu.pattern.{NamedSeqPattern, SeqPattern, SeqPart} -import Parser.{ Combinators, maybeSpace, MaybeTupleOrParens } +import Parser.{Combinators, maybeSpace, MaybeTupleOrParens} import cats.implicits._ import Identifier.{Bindable, Constructor} @@ -20,16 +20,20 @@ sealed abstract class Pattern[+N, +T] { def mapType[U](fn: T => U): Pattern[N, U] = (new Pattern.InvariantPattern(this)).traverseType[cats.Id, U](fn) - /** - * List all the names that are bound in Vars inside this pattern - * in the left to right order they are encountered, without any duplication - */ + /** List all the names that are bound in Vars inside this pattern in the left + * to right order they are encountered, without any duplication + */ lazy val names: List[Bindable] = { @annotation.tailrec - def loop(stack: List[Pattern[N, T]], seen: Set[Bindable], acc: List[Bindable]): List[Bindable] = + def loop( + stack: List[Pattern[N, T]], + seen: Set[Bindable], + acc: List[Bindable] + ): List[Bindable] = stack match { case Nil => acc.reverse - case (Pattern.WildCard | Pattern.Literal(_)) :: tail => loop(tail, seen, acc) + case (Pattern.WildCard | Pattern.Literal(_)) :: tail => + loop(tail, seen, acc) case Pattern.Var(v) :: tail => if (seen(v)) loop(tail, seen, acc) else loop(tail, seen + v, v :: acc) @@ -37,11 +41,17 @@ sealed abstract class Pattern[+N, +T] { if (seen(v)) loop(p :: tail, seen, acc) else loop(p :: tail, seen + v, v :: acc) case Pattern.StrPat(items) :: tail => - val names = items.collect { case Pattern.StrPart.NamedStr(n) => n }.filterNot(seen) + val names = items + .collect { case Pattern.StrPart.NamedStr(n) => n } + .filterNot(seen) loop(tail, seen ++ names, names reverse_::: acc) case Pattern.ListPat(items) :: tail => - val globs = items.collect { case Pattern.ListPart.NamedList(glob) => glob }.filterNot(seen) - val next = items.collect { case Pattern.ListPart.Item(inner) => inner } + val globs = items + .collect { case Pattern.ListPart.NamedList(glob) => glob } + .filterNot(seen) + val next = items.collect { case Pattern.ListPart.Item(inner) => + inner + } loop(next ::: tail, seen ++ globs, globs reverse_::: acc) case Pattern.Annotation(p, _) :: tail => loop(p :: tail, seen, acc) case Pattern.PositionalStruct(_, params) :: tail => @@ -53,68 +63,88 @@ sealed abstract class Pattern[+N, +T] { loop(this :: Nil, Set.empty, Nil) } - /** - * What are the names that will be bound to the entire pattern, - * Bar(x) as foo would return List(foo) - * foo as bar as baz would return List(baz, bar, foo) - * Bar(x) would return Nil - */ + /** What are the names that will be bound to the entire pattern, Bar(x) as foo + * would return List(foo) foo as bar as baz would return List(baz, bar, foo) + * Bar(x) would return Nil + */ lazy val topNames: List[Bindable] = { this match { - case Pattern.Var(v) => v :: Nil + case Pattern.Var(v) => v :: Nil case Pattern.Named(v, p) => (v :: p.topNames).distinct case Pattern.ListPat(Pattern.ListPart.NamedList(n) :: Nil) => n :: Nil - case Pattern.Annotation(p, _) => p.topNames - case Pattern.Union(h, t) => + case Pattern.Annotation(p, _) => p.topNames + case Pattern.Union(h, t) => // the intersection of all top level names // is okay val pats = h :: t.toList val patIntr = pats.map(_.topNames.toSet).reduce(_ & _) // put them in the same order as written: pats.flatMap(_.topNames).iterator.filter(patIntr).toList.distinct - case Pattern.ListPat(_) | Pattern.WildCard | Pattern.Literal(_) | Pattern.StrPat(_) | Pattern.PositionalStruct(_, _) => Nil + case Pattern.ListPat(_) | Pattern.WildCard | Pattern.Literal(_) | + Pattern.StrPat(_) | Pattern.PositionalStruct(_, _) => + Nil } } - /** - * List all the names that strictly smaller than anything that would match this pattern - * e.g. a top level var, would not be returned - */ + /** List all the names that strictly smaller than anything that would match + * this pattern e.g. a top level var, would not be returned + */ def substructures: List[Bindable] = { - def cheat(stack: List[(Pattern[N, T], Boolean)], seen: Set[Bindable], acc: List[Bindable]): List[Bindable] = + def cheat( + stack: List[(Pattern[N, T], Boolean)], + seen: Set[Bindable], + acc: List[Bindable] + ): List[Bindable] = loop(stack, seen, acc) import Pattern.{ListPart, StrPart} @annotation.tailrec - def loop(stack: List[(Pattern[N, T], Boolean)], seen: Set[Bindable], acc: List[Bindable]): List[Bindable] = + def loop( + stack: List[(Pattern[N, T], Boolean)], + seen: Set[Bindable], + acc: List[Bindable] + ): List[Bindable] = stack match { case Nil => acc.reverse - case ((Pattern.WildCard, _) | (Pattern.Literal(_), _)) :: tail => loop(tail, seen, acc) + case ((Pattern.WildCard, _) | (Pattern.Literal(_), _)) :: tail => + loop(tail, seen, acc) case (Pattern.Var(v), isTop) :: tail => if (seen(v) || isTop) loop(tail, seen, acc) else loop(tail, seen + v, v :: acc) case (Pattern.Named(v, p), isTop) :: tail => if (seen(v) || isTop) loop((p, isTop) :: tail, seen, acc) else loop((p, isTop) :: tail, seen + v, v :: acc) - case (Pattern.StrPat(NonEmptyList(StrPart.NamedStr(_), Nil)), true) :: tail => - // this is a total match at the top level, not a substructure - loop(tail, seen, acc) + case ( + Pattern.StrPat(NonEmptyList(StrPart.NamedStr(_), Nil)), + true + ) :: tail => + // this is a total match at the top level, not a substructure + loop(tail, seen, acc) case (Pattern.StrPat(items), _) :: tail => - val globs = items.collect { case StrPart.NamedStr(glob) => glob }.filterNot(seen) + val globs = items + .collect { case StrPart.NamedStr(glob) => glob } + .filterNot(seen) loop(tail, seen ++ globs, globs reverse_::: acc) case (Pattern.ListPat(ListPart.NamedList(_) :: Nil), true) :: tail => - // this is a total match at the top level, not a substructure - loop(tail, seen, acc) + // this is a total match at the top level, not a substructure + loop(tail, seen, acc) case (Pattern.ListPat(items), _) :: tail => - val globs = items.collect { case ListPart.NamedList(glob) => glob }.filterNot(seen) - val next = items.collect { case ListPart.Item(inner) => (inner, false) } + val globs = items + .collect { case ListPart.NamedList(glob) => glob } + .filterNot(seen) + val next = items.collect { case ListPart.Item(inner) => + (inner, false) + } loop(next ::: tail, seen ++ globs, globs reverse_::: acc) - case (Pattern.Annotation(p, _), isTop) :: tail => loop((p, isTop) :: tail, seen, acc) + case (Pattern.Annotation(p, _), isTop) :: tail => + loop((p, isTop) :: tail, seen, acc) case (Pattern.PositionalStruct(_, params), _) :: tail => loop(params.map((_, false)) ::: tail, seen, acc) case (Pattern.Union(h, t), isTop) :: tail => - val all = (h :: t.toList).map { p => cheat((p, isTop) :: tail, seen, acc) } + val all = (h :: t.toList).map { p => + cheat((p, isTop) :: tail, seen, acc) + } // we need to be substructual on all: val intr = all.map(_.toSet).reduce(_.intersect(_)) all.flatMap(_.filter(intr)).distinct @@ -123,20 +153,18 @@ sealed abstract class Pattern[+N, +T] { loop((this, true) :: Nil, Set.empty, Nil) } - /** - * Return the pattern with all the binding names removed - */ + /** Return the pattern with all the binding names removed + */ def unbind: Pattern[N, T] = filterVars(Set.empty) - /** - * replace all Var names with Wildcard that are not - * satifying the keep predicate - */ + /** replace all Var names with Wildcard that are not satifying the keep + * predicate + */ def filterVars(keep: Bindable => Boolean): Pattern[N, T] = this match { case Pattern.WildCard | Pattern.Literal(_) => this - case p@Pattern.Var(v) => + case p @ Pattern.Var(v) => if (keep(v)) p else Pattern.WildCard case Pattern.Named(v, p) => val inner = p.filterVars(keep) @@ -144,15 +172,15 @@ sealed abstract class Pattern[+N, +T] { else inner case Pattern.StrPat(items) => Pattern.StrPat(items.map { - case wl@(Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => wl - case in@Pattern.StrPart.NamedStr(n) => + case wl @ (Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => wl + case in @ Pattern.StrPart.NamedStr(n) => if (keep(n)) in else Pattern.StrPart.WildStr }) case Pattern.ListPat(items) => Pattern.ListPat(items.map { case Pattern.ListPart.WildList => Pattern.ListPart.WildList - case in@Pattern.ListPart.NamedList(n) => + case in @ Pattern.ListPart.NamedList(n) => if (keep(n)) in else Pattern.ListPart.WildList case Pattern.ListPart.Item(p) => @@ -166,23 +194,23 @@ sealed abstract class Pattern[+N, +T] { Pattern.Union(h.filterVars(keep), t.map(_.filterVars(keep))) } - /** - * a collision happens when the same binding happens twice - * not separated by a union - */ + /** a collision happens when the same binding happens twice not separated by a + * union + */ def collisionBinds: List[Bindable] = { def loop(pat: Pattern[N, T]): (Set[Bindable], List[Bindable]) = pat match { case Pattern.WildCard | Pattern.Literal(_) => (Set.empty, Nil) - case Pattern.Var(v) => (Set(v), Nil) + case Pattern.Var(v) => (Set(v), Nil) case Pattern.Named(v, p) => val (s1, l1) = loop(p) if (s1(v)) (s1, v :: l1) else (s1 + v, l1) case Pattern.StrPat(items) => items.foldLeft((Set.empty[Bindable], List.empty[Bindable])) { - case (res, Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => res + case (res, Pattern.StrPart.WildStr | Pattern.StrPart.LitStr(_)) => + res case ((s1, l1), Pattern.StrPart.NamedStr(v)) => if (s1(v)) (s1, v :: l1) else (s1 + v, l1) @@ -214,32 +242,30 @@ sealed abstract class Pattern[+N, +T] { loop(this)._2.distinct.sorted } - /** - * @return the type if we can directly see it + /** @return + * the type if we can directly see it */ def simpleTypeOf: Option[T] = this match { - case Pattern.Named(_, p) => p.simpleTypeOf + case Pattern.Named(_, p) => p.simpleTypeOf case Pattern.Annotation(_, t) => Some(t) case Pattern.Union(_, _) | Pattern.ListPat(_) | Pattern.Literal(_) | - Pattern.WildCard | Pattern.Var(_) | Pattern.StrPat(_) | - Pattern.PositionalStruct(_, _) => None + Pattern.WildCard | Pattern.Var(_) | Pattern.StrPat(_) | + Pattern.PositionalStruct(_, _) => + None } } object Pattern { - /** - * Represents the different patterns that are all for structs - * (2, 3) - * Foo(2, 3) - * etc... - */ + /** Represents the different patterns that are all for structs (2, 3) Foo(2, + * 3) etc... + */ sealed abstract class StructKind { def namedStyle: Option[StructKind.Style] = this match { - case StructKind.Tuple => None - case StructKind.Named(_, style) => Some(style) + case StructKind.Tuple => None + case StructKind.Named(_, style) => Some(style) case StructKind.NamedPartial(_, style) => Some(style) } } @@ -267,7 +293,8 @@ object Pattern { // Represents a complete tuple-like pattern Foo(a, b) final case class Named(name: Constructor, style: Style) extends NamedKind // Represents a partial tuple-like pattern Foo(a, ...) - final case class NamedPartial(name: Constructor, style: Style) extends NamedKind + final case class NamedPartial(name: Constructor, style: Style) + extends NamedKind } sealed abstract class StrPart @@ -285,14 +312,14 @@ object Pattern { def document(q: Char): Document[StrPart] = Document.instance { case WildStr => wildDoc - case NamedStr(b) => prefix + Document[Bindable].document(b) + Doc.char('}') + case NamedStr(b) => + prefix + Document[Bindable].document(b) + Doc.char('}') case LitStr(s) => Doc.text(StringUtil.escape(q, s)) } } - /** - * represents items in a list pattern - */ + /** represents items in a list pattern + */ sealed abstract class ListPart[+A] { def map[B](fn: A => B): ListPart[B] } @@ -307,30 +334,29 @@ object Pattern { } } - /** - * This will match any list without any binding - */ + /** This will match any list without any binding + */ val AnyList: Pattern[Nothing, Nothing] = Pattern.ListPat(ListPart.WildList :: Nil) type Parsed = Pattern[StructKind, TypeRef] - /** - * Flatten a pattern out such that there are no top-level - * unions - */ + /** Flatten a pattern out such that there are no top-level unions + */ def flatten[N, T](p: Pattern[N, T]): NonEmptyList[Pattern[N, T]] = p match { case Union(h, t) => NonEmptyList(h, t.toList).flatMap(flatten(_)) - case nonU => NonEmptyList.one(nonU) + case nonU => NonEmptyList.one(nonU) } - /** - * Create a normalized pattern, which doesn't have nested top level unions - */ - def union[N, T](head: Pattern[N, T], tail: List[Pattern[N, T]]): Pattern[N, T] = { + /** Create a normalized pattern, which doesn't have nested top level unions + */ + def union[N, T]( + head: Pattern[N, T], + tail: List[Pattern[N, T]] + ): Pattern[N, T] = { NonEmptyList(head, tail).flatMap(flatten(_)) match { - case NonEmptyList(h, Nil) => h + case NonEmptyList(h, Nil) => h case NonEmptyList(h0, h1 :: tail) => Union(h0, NonEmptyList(h1, tail)) } } @@ -340,23 +366,27 @@ object Pattern { traversePattern[F, N, T1]( { (n, args) => args.map(PositionalStruct(n, _)) }, fn, - { parts => parts.map(ListPat(_)) }) + { parts => parts.map(ListPat(_)) } + ) - def mapStruct[N1](parts: (N, List[Pattern[N1, T]]) => Pattern[N1, T]): Pattern[N1, T] = + def mapStruct[N1]( + parts: (N, List[Pattern[N1, T]]) => Pattern[N1, T] + ): Pattern[N1, T] = traversePattern[cats.Id, N1, T](parts, t => t, ListPat(_)) def traversePattern[F[_]: Applicative, N1, T1]( - parts: (N, F[List[Pattern[N1, T1]]]) => F[Pattern[N1, T1]], - tpeFn: T => F[T1], - listFn: F[List[ListPart[Pattern[N1, T1]]]] => F[Pattern[N1, T1]]): F[Pattern[N1, T1]] = { + parts: (N, F[List[Pattern[N1, T1]]]) => F[Pattern[N1, T1]], + tpeFn: T => F[T1], + listFn: F[List[ListPart[Pattern[N1, T1]]]] => F[Pattern[N1, T1]] + ): F[Pattern[N1, T1]] = { lazy val pwild: F[Pattern[N1, T1]] = Applicative[F].pure(Pattern.WildCard) def go(pat: Pattern[N, T]): F[Pattern[N1, T1]] = pat match { - case Pattern.WildCard => pwild + case Pattern.WildCard => pwild case Pattern.Literal(lit) => Applicative[F].pure(Pattern.Literal(lit)) - case Pattern.Var(v) => Applicative[F].pure(Pattern.Var(v)) - case Pattern.StrPat(s) => Applicative[F].pure(Pattern.StrPat(s)) + case Pattern.Var(v) => Applicative[F].pure(Pattern.Var(v)) + case Pattern.StrPat(s) => Applicative[F].pure(Pattern.StrPat(s)) case Pattern.Named(v, p) => go(p).map(Pattern.Named(v, _)) case Pattern.ListPat(items) => @@ -383,20 +413,23 @@ object Pattern { } } - implicit class FoldablePattern[F[_], N, T](private val pats: F[Pattern[N, T]]) extends AnyVal { - def patternNames(implicit F: Foldable[F]): List[Bindable] = F.toList(pats).flatMap(_.names) + implicit class FoldablePattern[F[_], N, T](private val pats: F[Pattern[N, T]]) + extends AnyVal { + def patternNames(implicit F: Foldable[F]): List[Bindable] = + F.toList(pats).flatMap(_.names) } case object WildCard extends Pattern[Nothing, Nothing] case class Literal(toLit: Lit) extends Pattern[Nothing, Nothing] case class Var(name: Bindable) extends Pattern[Nothing, Nothing] - case class StrPat(parts: NonEmptyList[StrPart]) extends Pattern[Nothing, Nothing] { + case class StrPat(parts: NonEmptyList[StrPart]) + extends Pattern[Nothing, Nothing] { def isEmpty: Boolean = this == StrPat.Empty lazy val isTotal: Boolean = !parts.exists { case Pattern.StrPart.LitStr(_) => true - case _ => false + case _ => false } lazy val toNamedSeqPattern: NamedSeqPattern[Char] = @@ -413,28 +446,37 @@ object Pattern { isTotal || matcher(str).isDefined } - /** - * Patterns like Some(_) as foo - * as binds tighter than |, so use ( ) with groups you want to bind - */ - case class Named[N, T](name: Bindable, pat: Pattern[N, T]) extends Pattern[N, T] - case class ListPat[N, T](parts: List[ListPart[Pattern[N, T]]]) extends Pattern[N, T] { + /** Patterns like Some(_) as foo as binds tighter than |, so use ( ) with + * groups you want to bind + */ + case class Named[N, T](name: Bindable, pat: Pattern[N, T]) + extends Pattern[N, T] + case class ListPat[N, T](parts: List[ListPart[Pattern[N, T]]]) + extends Pattern[N, T] { lazy val toNamedSeqPattern: NamedSeqPattern[Pattern[N, T]] = ListPat.toNamedSeqPattern(this) lazy val toSeqPattern: SeqPattern[Pattern[N, T]] = toNamedSeqPattern.unname - def toPositionalStruct(empty: N, cons: N): Either[(ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), Pattern[N, T]] = { - def loop(parts: List[ListPart[Pattern[N, T]]]): Either[(ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), Pattern[N, T]] = + def toPositionalStruct(empty: N, cons: N): Either[ + (ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), + Pattern[N, T] + ] = { + def loop( + parts: List[ListPart[Pattern[N, T]]] + ): Either[(ListPart.Glob, NonEmptyList[ListPart[Pattern[N, T]]]), Pattern[ + N, + T + ]] = parts match { - case Nil => Right(PositionalStruct(empty, Nil)) + case Nil => Right(PositionalStruct(empty, Nil)) case ListPart.WildList :: Nil => Right(WildCard) case ListPart.NamedList(glob) :: Nil => Right(Var(glob)) - case ListPart.Item(p) :: tail => + case ListPart.Item(p) :: tail => // we can always make some progress here val tailPat = loop(tail).toOption.getOrElse(ListPat(tail)) Right(PositionalStruct(cons, List(p, tailPat))) - case (l@ListPart.WildList) :: (r@ListPart.Item(WildCard)) :: t => + case (l @ ListPart.WildList) :: (r @ ListPart.Item(WildCard)) :: t => // we can switch *_, _ with _, *_ loop(r :: l :: t) case (glob: ListPart.Glob) :: h1 :: t => @@ -445,9 +487,12 @@ object Pattern { loop(parts) } } - case class Annotation[N, T](pattern: Pattern[N, T], tpe: T) extends Pattern[N, T] - case class PositionalStruct[N, T](name: N, params: List[Pattern[N, T]]) extends Pattern[N, T] - case class Union[N, T](head: Pattern[N, T], rest: NonEmptyList[Pattern[N, T]]) extends Pattern[N, T] + case class Annotation[N, T](pattern: Pattern[N, T], tpe: T) + extends Pattern[N, T] + case class PositionalStruct[N, T](name: N, params: List[Pattern[N, T]]) + extends Pattern[N, T] + case class Union[N, T](head: Pattern[N, T], rest: NonEmptyList[Pattern[N, T]]) + extends Pattern[N, T] object ListPat { val Wild: ListPat[Nothing, Nothing] = @@ -456,7 +501,10 @@ object Pattern { def fromSeqPattern[N, T](sp: SeqPattern[Pattern[N, T]]): ListPat[N, T] = { @annotation.tailrec - def loop(ps: List[SeqPart[Pattern[N, T]]], front: List[ListPart[Pattern[N, T]]]): List[ListPart[Pattern[N, T]]] = + def loop( + ps: List[SeqPart[Pattern[N, T]]], + front: List[ListPart[Pattern[N, T]]] + ): List[ListPart[Pattern[N, T]]] = ps match { case Nil => front.reverse case SeqPart.Lit(p) :: tail => @@ -473,19 +521,25 @@ object Pattern { ListPat(loop(sp.toList, Nil)) } - def toNamedSeqPattern[N, T](lp: ListPat[N, T]): NamedSeqPattern[Pattern[N, T]] = { - def partToNsp(lp: ListPart[Pattern[N, T]]): NamedSeqPattern[Pattern[N, T]] = + def toNamedSeqPattern[N, T]( + lp: ListPat[N, T] + ): NamedSeqPattern[Pattern[N, T]] = { + def partToNsp( + lp: ListPart[Pattern[N, T]] + ): NamedSeqPattern[Pattern[N, T]] = lp match { case ListPart.Item(WildCard) => NamedSeqPattern.Any - case ListPart.Item(p) => NamedSeqPattern.fromLit(p) - case ListPart.WildList => NamedSeqPattern.Wild + case ListPart.Item(p) => NamedSeqPattern.fromLit(p) + case ListPart.WildList => NamedSeqPattern.Wild case ListPart.NamedList(n) => NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild) } - def loop(lp: List[ListPart[Pattern[N, T]]]): NamedSeqPattern[Pattern[N, T]] = + def loop( + lp: List[ListPart[Pattern[N, T]]] + ): NamedSeqPattern[Pattern[N, T]] = lp match { - case Nil => NamedSeqPattern.NEmpty + case Nil => NamedSeqPattern.NEmpty case h :: Nil => partToNsp(h) case h :: t => NamedSeqPattern.NCat(partToNsp(h), loop(t)) @@ -504,7 +558,10 @@ object Pattern { if (rev.isEmpty) Nil else StrPart.LitStr(rev.reverse.mkString) :: Nil - def loop(ps: List[SeqPart[Char]], front: List[Char]): NonEmptyList[StrPart] = + def loop( + ps: List[SeqPart[Char]], + front: List[Char] + ): NonEmptyList[StrPart] = ps match { case Nil => NonEmptyList.fromList(lit(front)).getOrElse(Empty.parts) case SeqPart.Lit(c) :: tail => @@ -522,7 +579,7 @@ object Pattern { val tailRes = loop(tail, Nil).prepend(StrPart.WildStr) NonEmptyList.fromList(lit(front)) match { - case None => tailRes + case None => tailRes case Some(h) => h ::: tailRes } } @@ -535,7 +592,7 @@ object Pattern { s match { case StrPart.NamedStr(n) => NamedSeqPattern.Bind(n.sourceCodeRepr, NamedSeqPattern.Wild) - case StrPart.WildStr => NamedSeqPattern.Wild + case StrPart.WildStr => NamedSeqPattern.Wild case StrPart.LitStr(s) => // reverse so we can build right associated s.toList.reverse match { @@ -561,18 +618,13 @@ object Pattern { StrPat(NonEmptyList.one(StrPart.LitStr(s))) } - /** - * If this pattern is: - * x - * (x: T) - * unnamed as x - * x | x | x - * then it is "SinglyNamed" - */ + /** If this pattern is: x (x: T) unnamed as x x | x | x then it is + * "SinglyNamed" + */ object SinglyNamed { def unapply[N, T](p: Pattern[N, T]): Option[Bindable] = p match { - case Var(b) => Some(b) + case Var(b) => Some(b) case Annotation(SinglyNamed(b), _) => Some(b) case Named(b, inner) => if (inner.names.isEmpty) Some(b) @@ -585,7 +637,8 @@ object Pattern { } } - implicit def patternOrdering[N: Ordering, T: Ordering]: Ordering[Pattern[N, T]] = + implicit def patternOrdering[N: Ordering, T: Ordering] + : Ordering[Pattern[N, T]] = new Ordering[Pattern[N, T]] { val ordN = implicitly[Ordering[N]] val ordT = implicitly[Ordering[T]] @@ -595,14 +648,14 @@ object Pattern { new Ordering[ListPart[A]] { def compare(a: ListPart[A], b: ListPart[A]) = (a, b) match { - case (ListPart.WildList, ListPart.WildList) => 0 - case (ListPart.WildList, _) => -1 + case (ListPart.WildList, ListPart.WildList) => 0 + case (ListPart.WildList, _) => -1 case (ListPart.NamedList(_), ListPart.WildList) => 1 case (ListPart.NamedList(a), ListPart.NamedList(b)) => ordBin.compare(a, b) case (ListPart.NamedList(_), ListPart.Item(_)) => -1 case (ListPart.Item(a), ListPart.Item(b)) => ordA.compare(a, b) - case (ListPart.Item(_), _) => 1 + case (ListPart.Item(_), _) => 1 } } @@ -613,13 +666,13 @@ object Pattern { def compare(a: StrPart, b: StrPart) = (a, b) match { - case (WildStr, WildStr) => 0 - case (WildStr, _) => -1 - case (LitStr(_), WildStr) => 1 - case (LitStr(sa), LitStr(sb)) => sa.compareTo(sb) - case (LitStr(_), NamedStr(_)) => -1 + case (WildStr, WildStr) => 0 + case (WildStr, _) => -1 + case (LitStr(_), WildStr) => 1 + case (LitStr(sa), LitStr(sb)) => sa.compareTo(sb) + case (LitStr(_), NamedStr(_)) => -1 case (NamedStr(na), NamedStr(nb)) => ordBin.compare(na, nb) - case (NamedStr(_), _) => 1 + case (NamedStr(_), _) => 1 } } val strOrd = ListOrdering.onType(ordStrPart) @@ -628,30 +681,34 @@ object Pattern { def compare(a: Pattern[N, T], b: Pattern[N, T]): Int = (a, b) match { - case (WildCard, WildCard) => 0 - case (WildCard, _) => -1 - case (Literal(_), WildCard) => 1 - case (Literal(a), Literal(b)) => Lit.litOrdering.compare(a, b) - case (Literal(_), _) => -1 + case (WildCard, WildCard) => 0 + case (WildCard, _) => -1 + case (Literal(_), WildCard) => 1 + case (Literal(a), Literal(b)) => Lit.litOrdering.compare(a, b) + case (Literal(_), _) => -1 case (Var(_), WildCard | Literal(_)) => 1 - case (Var(a), Var(b)) => compIdent.compare(a, b) - case (Var(_), _) => -1 + case (Var(a), Var(b)) => compIdent.compare(a, b) + case (Var(_), _) => -1 case (Named(_, _), WildCard | Literal(_) | Var(_)) => 1 case (Named(n1, p1), Named(n2, p2)) => val c = compIdent.compare(n1, n2) if (c == 0) compare(p1, p2) else c - case (Named(_, _), _) => -1 + case (Named(_, _), _) => -1 case (StrPat(_), WildCard | Literal(_) | Var(_) | Named(_, _)) => 1 case (StrPat(as), StrPat(bs)) => strOrd.compare(as.toList, bs.toList) - case (StrPat(_), _) => -1 - case (ListPat(_), WildCard | Literal(_) | Var(_) | Named(_, _) | StrPat(_)) => 1 + case (StrPat(_), _) => -1 + case ( + ListPat(_), + WildCard | Literal(_) | Var(_) | Named(_, _) | StrPat(_) + ) => + 1 case (ListPat(as), ListPat(bs)) => listE.compare(as, bs) - case (ListPat(_), _) => -1 + case (ListPat(_), _) => -1 case (Annotation(_, _), PositionalStruct(_, _) | Union(_, _)) => -1 case (Annotation(a0, t0), Annotation(a1, t1)) => val c = compare(a0, a1) if (c == 0) ordT.compare(t0, t1) else c - case (Annotation(_, _), _) => 1 + case (Annotation(_, _), _) => 1 case (PositionalStruct(_, _), Union(_, _)) => -1 case (PositionalStruct(n0, a0), PositionalStruct(n1, a1)) => val c = ordN.compare(n0, n1) @@ -665,32 +722,39 @@ object Pattern { implicit def document[T: Document]: Document[Pattern[StructKind, T]] = Document.instance[Pattern[StructKind, T]] { - case WildCard => Doc.char('_') - case Literal(lit) => Document[Lit].document(lit) - case Var(n) => Document[Identifier].document(n) - case Named(n, u@Union(_, _)) => + case WildCard => Doc.char('_') + case Literal(lit) => Document[Lit].document(lit) + case Var(n) => Document[Identifier].document(n) + case Named(n, u @ Union(_, _)) => // union is also an operator, so we need to use parens to explicitly bind | more tightly // than the @ on the left. - Doc.char('(') + document.document(u) + Doc.char(')') + Doc.text(" as ") + Document[Identifier].document(n) + Doc.char('(') + document.document(u) + Doc.char(')') + Doc.text( + " as " + ) + Document[Identifier].document(n) case Named(n, p) => - document.document(p) + Doc.text(" as ") + Document[Identifier].document(n) + document.document(p) + Doc.text(" as ") + Document[Identifier].document( + n + ) case StrPat(items) => // prefer ' if possible, else use " val useDouble = items.exists { case StrPart.LitStr(str) => str.contains('\'') && !str.contains('"') - case _ => false + case _ => false } val q = if (useDouble) '"' else '\'' val sd = StrPart.document(q) val inner = Doc.intercalate(Doc.empty, items.toList.map(sd.document(_))) Doc.char(q) + inner + Doc.char(q) case ListPat(items) => - Doc.char('[') + Doc.intercalate(Doc.text(", "), + Doc.char('[') + Doc.intercalate( + Doc.text(", "), items.map { case ListPart.WildList => Doc.text("*_") - case ListPart.NamedList(glob) => Doc.char('*') + Document[Identifier].document(glob) + case ListPart.NamedList(glob) => + Doc.char('*') + Document[Identifier].document(glob) case ListPart.Item(p) => document.document(p) - }) + Doc.char(']') + } + ) + Doc.char(']') case Annotation(p, t) => /* * We need to know what package we are in and what imports we depend on here. @@ -741,13 +805,16 @@ object Pattern { // of fields here val cspace = Doc.text(": ") val identDoc = Document[Identifier] - val kvargs = Doc.intercalate(Doc.text(", "), - fields.toList.zip(args) + val kvargs = Doc.intercalate( + Doc.text(", "), + fields.toList + .zip(args) .map { case (StructKind.Style.FieldKind.Explicit(n), adoc) => identDoc.document(n) + cspace + adoc case (StructKind.Style.FieldKind.Implicit(_), adoc) => adoc - }) + } + ) prefix + Doc.text(" {") + kvargs + @@ -768,42 +835,49 @@ object Pattern { } def recordPat[N <: StructKind.NamedKind]( - name: Constructor, - args: NonEmptyList[Either[Bindable, (Bindable, Parsed)]])( - fn: (Constructor, StructKind.Style) => N): PositionalStruct[StructKind, TypeRef] = { + name: Constructor, + args: NonEmptyList[Either[Bindable, (Bindable, Parsed)]] + )( + fn: (Constructor, StructKind.Style) => N + ): PositionalStruct[StructKind, TypeRef] = { val fields = args.map { - case Left(b) => StructKind.Style.FieldKind.Implicit(b) + case Left(b) => StructKind.Style.FieldKind.Implicit(b) case Right((b, _)) => StructKind.Style.FieldKind.Explicit(b) } val structArgs = args.toList.map { - case Left(b) => Pattern.Var(b) + case Left(b) => Pattern.Var(b) case Right((_, p)) => p } - PositionalStruct( - fn(name, StructKind.Style.RecordLike(fields)), - structArgs) + PositionalStruct(fn(name, StructKind.Style.RecordLike(fields)), structArgs) } - def compiledDocument[A: Document]: Document[Pattern[(PackageName, Constructor), A]] = { - lazy val doc: Document[Pattern[(PackageName, Constructor), A]] = compiledDocument[A] + def compiledDocument[A: Document] + : Document[Pattern[(PackageName, Constructor), A]] = { + lazy val doc: Document[Pattern[(PackageName, Constructor), A]] = + compiledDocument[A] Document.instance[Pattern[(PackageName, Constructor), A]] { - case WildCard => Doc.char('_') - case Literal(lit) => Document[Lit].document(lit) - case Var(n) => Document[Identifier].document(n) - case Named(n, u@Union(_, _)) => + case WildCard => Doc.char('_') + case Literal(lit) => Document[Lit].document(lit) + case Var(n) => Document[Identifier].document(n) + case Named(n, u @ Union(_, _)) => // union is also an operator, so we need to use parens to explicitly bind | more tightly // than the as on the left. - Doc.char('(') + doc.document(u) + Doc.char(')') + Doc.text(" as ") + Document[Identifier].document(n) + Doc.char('(') + doc.document(u) + Doc.char(')') + Doc.text( + " as " + ) + Document[Identifier].document(n) case Named(n, p) => doc.document(p) + Doc.text(" as ") + Document[Identifier].document(n) case StrPat(items) => document.document(StrPat(items)) case ListPat(items) => - Doc.char('[') + Doc.intercalate(Doc.text(", "), + Doc.char('[') + Doc.intercalate( + Doc.text(", "), items.map { case ListPart.WildList => Doc.text("*_") - case ListPart.NamedList(glob) => Doc.char('*') + Document[Identifier].document(glob) + case ListPart.NamedList(glob) => + Doc.char('*') + Document[Identifier].document(glob) case ListPart.Item(p) => doc.document(p) - }) + Doc.char(']') + } + ) + Doc.char(']') case Annotation(p, t) => /* * We need to know what package we are in and what imports we depend on here. @@ -814,12 +888,20 @@ object Pattern { * case */ doc.document(p) + Doc.text(": ") + Document[A].document(t) - case ps@PositionalStruct((_, c), a) => - def untuple(p: Pattern[(PackageName, Constructor), A]): Option[List[Doc]] = + case ps @ PositionalStruct((_, c), a) => + def untuple( + p: Pattern[(PackageName, Constructor), A] + ): Option[List[Doc]] = p match { - case PositionalStruct((PackageName.PredefName, Constructor("Unit")), Nil) => + case PositionalStruct( + (PackageName.PredefName, Constructor("Unit")), + Nil + ) => Some(Nil) - case PositionalStruct((PackageName.PredefName, Constructor("TupleCons")), a :: b :: Nil) => + case PositionalStruct( + (PackageName.PredefName, Constructor("TupleCons")), + a :: b :: Nil + ) => untuple(b).map { l => doc.document(a) :: l } case _ => None } @@ -833,7 +915,7 @@ object Pattern { case None => val args = a match { case Nil => Doc.empty - case _ => tup(a.map(doc.document(_))) + case _ => tup(a.map(doc.document(_))) } Doc.text(c.asString) + args } @@ -851,26 +933,27 @@ object Pattern { } } - /** - * For fully typed patterns, compute the type environment of the bindings - * from this pattern. This will sys.error if you pass a bad pattern, which - * you should never do (and this code will never do unless there is some - * broken invariant) - */ - def envOf[C, K, T](p: Pattern[C, T], env: Map[K, T])(kfn: Identifier => K): Map[K, T] = { + /** For fully typed patterns, compute the type environment of the bindings + * from this pattern. This will sys.error if you pass a bad pattern, which + * you should never do (and this code will never do unless there is some + * broken invariant) + */ + def envOf[C, K, T](p: Pattern[C, T], env: Map[K, T])( + kfn: Identifier => K + ): Map[K, T] = { def update(env: Map[K, T], n: Identifier, typeOf: Option[T]): Map[K, T] = - typeOf match { - case None => - // $COVERAGE-OFF$ should be unreachable - sys.error(s"no type found for $n in $p") - // $COVERAGE-ON$ should be unreachable - case Some(t) => env.updated(kfn(n), t) - } + typeOf match { + case None => + // $COVERAGE-OFF$ should be unreachable + sys.error(s"no type found for $n in $p") + // $COVERAGE-ON$ should be unreachable + case Some(t) => env.updated(kfn(n), t) + } def loop(p0: Pattern[C, T], typeOf: Option[T], env: Map[K, T]): Map[K, T] = p0 match { - case WildCard => env + case WildCard => env case Literal(_) => env - case Var(n) => update(env, n, typeOf) + case Var(n) => update(env, n, typeOf) case Named(n, p1) => val e1 = loop(p1, typeOf, env) update(e1, n, typeOf) @@ -879,12 +962,12 @@ object Pattern { items .foldLeft(env) { case (env, StrPart.NamedStr(n)) => update(env, n, typeOf) - case (env, _) => env + case (env, _) => env } case ListPat(items) => items.foldLeft(env) { - case (env, ListPart.WildList) => env + case (env, ListPart.WildList) => env case (env, ListPart.NamedList(n)) => // the type of a named sub-list is // the same as the type of the list @@ -913,37 +996,38 @@ object Pattern { val pname = Identifier.bindableParser.map(StrPart.NamedStr(_)) def strp(q: Char): P[List[StrPart]] = - StringUtil.interpolatedString(q, start, pwild.orElse(pname), end) + StringUtil + .interpolatedString(q, start, pwild.orElse(pname), end) .map(_.map { - case Left(p) => p + case Left(p) => p case Right((_, str)) => StrPart.LitStr(str) }) val eitherString = strp('\'') <+> strp('"') // don't emit complex patterns for simple strings: val str = eitherString.map { - case Nil => Literal(Lit.EmptyStr) + case Nil => Literal(Lit.EmptyStr) case StrPart.LitStr(str) :: Nil => Literal(Lit.Str(str)) - case h :: tail => StrPat(NonEmptyList(h, tail)) + case h :: tail => StrPat(NonEmptyList(h, tail)) } str <+> intp } - /** - * This does not allow a top-level type annotation which would be ambiguous - * with : used for ending the match case block - */ + /** This does not allow a top-level type annotation which would be ambiguous + * with : used for ending the match case block + */ val matchParser: P[Parsed] = P.defer(matchOrNot(isMatch = true)) - /** - * A Pattern in a match position allows top level un-parenthesized type annotation - */ + /** A Pattern in a match position allows top level un-parenthesized type + * annotation + */ val bindParser: P[Parsed] = P.defer(matchOrNot(isMatch = false)) - private val maybePartial: P0[(Constructor, StructKind.Style) => StructKind.NamedKind] = { + private val maybePartial + : P0[(Constructor, StructKind.Style) => StructKind.NamedKind] = { val partial = (maybeSpace.soft ~ P.string("...")).as( { (n: Constructor, s: StructKind.Style) => StructKind.NamedPartial(n, s) } ) @@ -955,24 +1039,30 @@ object Pattern { partial.orElse(notPartial) } - private def parseRecordStruct(recurse: P0[Parsed]): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { + private def parseRecordStruct( + recurse: P0[Parsed] + ): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { // We do maybeSpace, then { } then either a Bindable or Bindable: Pattern // maybe followed by ... val item: P[Either[Bindable, (Bindable, Parsed)]] = - (Identifier.bindableParser ~ ((maybeSpace.soft ~ P.char(':') ~ maybeSpace) *> recurse).?) + (Identifier.bindableParser ~ ((maybeSpace.soft ~ P.char( + ':' + ) ~ maybeSpace) *> recurse).?) .map { - case (b, None) => Left(b) + case (b, None) => Left(b) case (b, Some(pat)) => Right((b, pat)) } val items = item.nonEmptyList ~ maybePartial - ((maybeSpace.with1.soft ~ P.char('{') ~ maybeSpace) *> items <* (maybeSpace ~ P.char('}'))) - .map { case (args, fn) => - { (c: Constructor) => recordPat(c, args)(fn) } - } + ((maybeSpace.with1.soft ~ P.char( + '{' + ) ~ maybeSpace) *> items <* (maybeSpace ~ P.char('}'))) + .map { case (args, fn) => { (c: Constructor) => recordPat(c, args)(fn) } } } - private def parseTupleStruct(recurse: P[Parsed]): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { + private def parseTupleStruct( + recurse: P[Parsed] + ): P[Constructor => PositionalStruct[StructKind, TypeRef]] = { // There are three cases: // Foo(1 or more patterns) // Foo(1 or more patterns, ...) @@ -980,13 +1070,19 @@ object Pattern { val oneOrMore = recurse.nonEmptyList.map(_.toList) ~ maybePartial val onlyPartial = P.string("...").as { - (Nil, { (n: Constructor, s: StructKind.Style) => StructKind.NamedPartial(n, s) }) + ( + Nil, + { (n: Constructor, s: StructKind.Style) => + StructKind.NamedPartial(n, s) + } + ) } - (oneOrMore <+> onlyPartial) - .parensCut - .map { case (args, fn) => - { (n: Constructor) => PositionalStruct(fn(n, StructKind.Style.TupleLike), args) } + (oneOrMore <+> onlyPartial).parensCut + .map { + case (args, fn) => { (n: Constructor) => + PositionalStruct(fn(n, StructKind.Style.TupleLike), args) + } } } @@ -996,30 +1092,35 @@ object Pattern { def isNonUnitTuple(arg: Parsed): Boolean = arg match { case PositionalStruct(StructKind.Tuple, args) => args.nonEmpty - case _ => false + case _ => false } def fromTupleOrParens(e: Either[Parsed, List[Parsed]]): Parsed = e match { - case Right(tup) => tuple(tup) + case Right(tup) => tuple(tup) case Left(parens) => parens } def fromMaybeTupleOrParens(p: MaybeTupleOrParens[Parsed]): Parsed = p match { - case MaybeTupleOrParens.Bare(b) => b + case MaybeTupleOrParens.Bare(b) => b case MaybeTupleOrParens.Parens(p) => p - case MaybeTupleOrParens.Tuple(p) => tuple(p) + case MaybeTupleOrParens.Tuple(p) => tuple(p) } private def matchOrNot(isMatch: Boolean): P[Parsed] = { val recurse = P.defer(bindParser) val positional = - (Identifier.consParser ~ (parseTupleStruct(recurse) <+> parseRecordStruct(recurse)).?) + (Identifier.consParser ~ (parseTupleStruct(recurse) <+> parseRecordStruct( + recurse + )).?) .map { case (n, None) => - PositionalStruct(StructKind.Named(n, StructKind.Style.TupleLike), Nil) + PositionalStruct( + StructKind.Named(n, StructKind.Style.TupleLike), + Nil + ) case (n, Some(fn)) => fn(n) } @@ -1039,10 +1140,16 @@ object Pattern { val pvar = Identifier.bindableParser.map(Var(_)) val nonAnnotated = - P.defer(P.oneOf(plit :: pwild :: tupleOrParens :: positional :: listP :: pvar :: Nil)) + P.defer( + P.oneOf( + plit :: pwild :: tupleOrParens :: positional :: listP :: pvar :: Nil + ) + ) val namedOp: P[Parsed => Parsed] = - ((maybeSpace.with1 *> P.string("as") <* Parser.spaces).backtrack *> Identifier.bindableParser) + ((maybeSpace.with1 *> P.string( + "as" + ) <* Parser.spaces).backtrack *> Identifier.bindableParser) .map { n => { (pat: Parsed) => Named(n, pat) } } @@ -1074,4 +1181,3 @@ object Pattern { else withAs.maybeAp(unionOp.orElse(typeAnnotOp)) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Predef.scala b/core/src/main/scala/org/bykn/bosatsu/Predef.scala index 49ad05b35..cd5914d45 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Predef.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Predef.scala @@ -5,16 +5,16 @@ import java.math.BigInteger import language.experimental.macros object Predef { - /** - * Loads a file *at compile time* as a means of embedding - * external files into strings. This lets us avoid resources - * which compilicate matters for scalajs. - */ - private[bosatsu] def loadFileInCompile(file: String): String = macro Macro.loadFileInCompileImpl - - /** - * String representation of the predef - */ + + /** Loads a file *at compile time* as a means of embedding external files into + * strings. This lets us avoid resources which compilicate matters for + * scalajs. + */ + private[bosatsu] def loadFileInCompile(file: String): String = + macro Macro.loadFileInCompileImpl + + /** String representation of the predef + */ val predefString: String = loadFileInCompile("core/src/main/resources/bosatsu/predef.bosatsu") @@ -22,8 +22,7 @@ object Predef { PackageName.PredefName val jvmExternals: Externals = - Externals - .empty + Externals.empty .add(packageName, "add", FfiCall.Fn2(PredefImpl.add(_, _))) .add(packageName, "div", FfiCall.Fn2(PredefImpl.div(_, _))) .add(packageName, "sub", FfiCall.Fn2(PredefImpl.sub(_, _))) @@ -33,12 +32,32 @@ object Predef { .add(packageName, "gcd_Int", FfiCall.Fn2(PredefImpl.gcd_Int(_, _))) .add(packageName, "mod_Int", FfiCall.Fn2(PredefImpl.mod_Int(_, _))) .add(packageName, "int_loop", FfiCall.Fn3(PredefImpl.intLoop(_, _, _))) - .add(packageName, "int_to_String", FfiCall.Fn1(PredefImpl.int_to_String(_))) + .add( + packageName, + "int_to_String", + FfiCall.Fn1(PredefImpl.int_to_String(_)) + ) .add(packageName, "trace", FfiCall.Fn2(PredefImpl.trace(_, _))) - .add(packageName, "string_Order_fn", FfiCall.Fn2(PredefImpl.string_Order_Fn(_, _))) - .add(packageName, "concat_String", FfiCall.Fn1(PredefImpl.concat_String(_))) - .add(packageName, "partition_String", FfiCall.Fn2(PredefImpl.partitionString(_, _))) - .add(packageName, "rpartition_String", FfiCall.Fn2(PredefImpl.rightPartitionString(_, _))) + .add( + packageName, + "string_Order_fn", + FfiCall.Fn2(PredefImpl.string_Order_Fn(_, _)) + ) + .add( + packageName, + "concat_String", + FfiCall.Fn1(PredefImpl.concat_String(_)) + ) + .add( + packageName, + "partition_String", + FfiCall.Fn2(PredefImpl.partitionString(_, _)) + ) + .add( + packageName, + "rpartition_String", + FfiCall.Fn2(PredefImpl.rightPartitionString(_, _)) + ) } object PredefImpl { @@ -48,7 +67,7 @@ object PredefImpl { private def i(a: Value): BigInteger = a match { case VInt(bi) => bi - case _ => sys.error(s"expected integer: $a") + case _ => sys.error(s"expected integer: $a") } def add(a: Value, b: Value): Value = @@ -115,7 +134,7 @@ object PredefImpl { def gcd_Int(a: Value, b: Value): Value = VInt(gcdBigInteger(i(a), i(b))) - //def intLoop(intValue: Int, state: a, fn: Int -> a -> TupleCons[Int, TupleCons[a, Unit]]) -> a + // def intLoop(intValue: Int, state: a, fn: Int -> a -> TupleCons[Int, TupleCons[a, Unit]]) -> a final def intLoop(intValue: Value, state: Value, fn: Value): Value = { val fnT = fn.asFn @@ -129,9 +148,9 @@ object PredefImpl { if (n.compareTo(bi) >= 0) { // we are done in this case nextA - } - else loop(nextI, n, nextA) - case other => sys.error(s"unexpected ill-typed value: at $bi, $state, $other") + } else loop(nextI, n, nextA) + case other => + sys.error(s"unexpected ill-typed value: at $bi, $state, $other") } } @@ -159,17 +178,16 @@ object PredefImpl { case Value.VList(parts) => Value.Str(parts.iterator.map { case Value.Str(s) => s - case other => - //$COVERAGE-OFF$ + case other => + // $COVERAGE-OFF$ sys.error(s"type error: $other") - //$COVERAGE-ON$ - } - .mkString) + // $COVERAGE-ON$ + }.mkString) case other => - //$COVERAGE-OFF$ + // $COVERAGE-OFF$ sys.error(s"type error: $other") - //$COVERAGE-ON$ + // $COVERAGE-ON$ } // return an Option[(String, String)] @@ -182,11 +200,12 @@ object PredefImpl { val idx = argS.indexOf(sepS) if (idx < 0) Value.VOption.none - else Value.VOption.some { - val left = argS.substring(0, idx) - val right = argS.substring(idx + sepS.length) - Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) - } + else + Value.VOption.some { + val left = argS.substring(0, idx) + val right = argS.substring(idx + sepS.length) + Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) + } } } @@ -198,12 +217,12 @@ object PredefImpl { val argS = arg.asExternal.toAny.asInstanceOf[String] val idx = argS.lastIndexOf(sepS) if (idx < 0) Value.VOption.none - else Value.VOption.some { - val left = argS.substring(0, idx) - val right = argS.substring(idx + sepS.length) - Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) - } + else + Value.VOption.some { + val left = argS.substring(0, idx) + val right = argS.substring(idx + sepS.length) + Value.Tuple(Value.ExternalValue(left), Value.ExternalValue(right)) + } } } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/Program.scala b/core/src/main/scala/org/bykn/bosatsu/Program.scala index c1684fa80..f1729c74e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Program.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Program.scala @@ -3,10 +3,11 @@ package org.bykn.bosatsu import Identifier.Bindable case class Program[+T, +D, +S]( - types: T, - lets: List[(Bindable, RecursionKind, D)], - externalDefs: List[Bindable], - from: S) { + types: T, + lets: List[(Bindable, RecursionKind, D)], + externalDefs: List[Bindable], + from: S +) { private[this] lazy val letMap: Map[Bindable, (RecursionKind, D)] = lets.iterator.map { case (n, r, d) => (n, (r, d)) }.toMap diff --git a/core/src/main/scala/org/bykn/bosatsu/Referant.scala b/core/src/main/scala/org/bykn/bosatsu/Referant.scala index 4a5ad2d40..71cff11c3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Referant.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Referant.scala @@ -6,20 +6,23 @@ import rankn.{ConstructorFn, DefinedType, Type, TypeEnv} import Identifier.{Constructor => ConstructorName} -/** - * A Referant is something that can be exported or imported after resolving - * Before resolving, imports and exports are just names. - */ +/** A Referant is something that can be exported or imported after resolving + * Before resolving, imports and exports are just names. + */ sealed abstract class Referant[+A] { // if this is a Constructor or DefinedT, return the associated DefinedType def definedType: Option[DefinedType[A]] = this match { - case Referant.Value(_) => None - case Referant.DefinedT(dt) => Some(dt) + case Referant.Value(_) => None + case Referant.DefinedT(dt) => Some(dt) case Referant.Constructor(dt, _) => Some(dt) } - def addTo[A1 >: A](packageName: PackageName, name: Identifier, te: TypeEnv[A1]): TypeEnv[A1] = + def addTo[A1 >: A]( + packageName: PackageName, + name: Identifier, + te: TypeEnv[A1] + ): TypeEnv[A1] = this match { case Referant.Value(t) => te.addExternalValue(packageName, name, t) @@ -33,71 +36,84 @@ sealed abstract class Referant[+A] { object Referant { case class Value(scheme: Type) extends Referant[Nothing] case class DefinedT[A](dtype: DefinedType[A]) extends Referant[A] - case class Constructor[A](dtype: DefinedType[A], fn: ConstructorFn) extends Referant[A] + case class Constructor[A](dtype: DefinedType[A], fn: ConstructorFn) + extends Referant[A] - private def imported[A, B, C](imps: List[Import[A, NonEmptyList[Referant[C]]]])(fn: PartialFunction[Referant[C], B]): Map[Identifier, B] = + private def imported[A, B, C]( + imps: List[Import[A, NonEmptyList[Referant[C]]]] + )(fn: PartialFunction[Referant[C], B]): Map[Identifier, B] = imps.foldLeft(Map.empty[Identifier, B]) { (m0, imp) => m0 ++ Import.locals(imp)(fn) } - def importedTypes[A, B](imps: List[Import[A, NonEmptyList[Referant[B]]]]): Map[Identifier, (PackageName, TypeName)] = - imported(imps) { - case Referant.DefinedT(dt) => (dt.packageName, dt.name) + def importedTypes[A, B]( + imps: List[Import[A, NonEmptyList[Referant[B]]]] + ): Map[Identifier, (PackageName, TypeName)] = + imported(imps) { case Referant.DefinedT(dt) => + (dt.packageName, dt.name) } - /** - * These are all the imported items that may be used in a match - */ - def importedConsNames[A, B](imps: List[Import[A, NonEmptyList[Referant[B]]]]): Map[Identifier, (PackageName, ConstructorName)] = - imported(imps) { - case Referant.Constructor(dt, fn) => (dt.packageName, fn.name) + /** These are all the imported items that may be used in a match + */ + def importedConsNames[A, B]( + imps: List[Import[A, NonEmptyList[Referant[B]]]] + ): Map[Identifier, (PackageName, ConstructorName)] = + imported(imps) { case Referant.Constructor(dt, fn) => + (dt.packageName, fn.name) } - /** - * Fully qualified original names - */ + /** Fully qualified original names + */ def fullyQualifiedImportedValues[A, B]( - imps: List[Import[A, NonEmptyList[Referant[B]]]])(nameOf: A => PackageName)(implicit ev: B <:< Kind.Arg): Map[(PackageName, Identifier), Type] = + imps: List[Import[A, NonEmptyList[Referant[B]]]] + )( + nameOf: A => PackageName + )(implicit ev: B <:< Kind.Arg): Map[(PackageName, Identifier), Type] = imps.iterator.flatMap { item => val pn = nameOf(item.pack) item.items.toList.iterator.flatMap { i => val orig = i.originalName val key = (pn, orig) i.tag.toList.iterator.collect { - case Referant.Value(t) => (key, t) + case Referant.Value(t) => (key, t) case Referant.Constructor(dt, fn) => (key, dt.fnTypeOf(fn)) } } - } - .toMap + }.toMap def typeConstructors[A, B]( - imps: List[Import[A, NonEmptyList[Referant[B]]]]): - Map[(PackageName, ConstructorName), (List[(Type.Var.Bound, B)], List[Type], Type.Const.Defined)] = { - val refs: Iterator[Referant[B]] = imps.iterator.flatMap(_.items.toList.iterator.flatMap(_.tag.toList)) + imps: List[Import[A, NonEmptyList[Referant[B]]]] + ): Map[ + (PackageName, ConstructorName), + (List[(Type.Var.Bound, B)], List[Type], Type.Const.Defined) + ] = { + val refs: Iterator[Referant[B]] = + imps.iterator.flatMap(_.items.toList.iterator.flatMap(_.tag.toList)) refs.collect { case Constructor(dt, fn) => - ((dt.packageName, fn.name), (dt.annotatedTypeParams, fn.args.map(_._2), dt.toTypeConst)) - } - .toMap + ( + (dt.packageName, fn.name), + (dt.annotatedTypeParams, fn.args.map(_._2), dt.toTypeConst) + ) + }.toMap } - /** - * Build the TypeEnv view of the given imports - */ - def importedTypeEnv[A, B](inps: List[Import[A, NonEmptyList[Referant[B]]]])(nameOf: A => PackageName): TypeEnv[B] = - inps.foldLeft((TypeEnv.empty): TypeEnv[B]) { - case (te, imps) => - val pack = nameOf(imps.pack) - imps.items.foldLeft(te) { (te, imp) => - val nm = imp.localName - imp.tag.foldLeft(te) { - case (te1, Referant.Value(t)) => - te1.addExternalValue(pack, nm, t) - case (te1, Referant.Constructor(dt, cf)) => - te1.addConstructor(pack, dt, cf) - case (te1, Referant.DefinedT(dt)) => - te1.addDefinedType(dt) - } + /** Build the TypeEnv view of the given imports + */ + def importedTypeEnv[A, B](inps: List[Import[A, NonEmptyList[Referant[B]]]])( + nameOf: A => PackageName + ): TypeEnv[B] = + inps.foldLeft((TypeEnv.empty): TypeEnv[B]) { case (te, imps) => + val pack = nameOf(imps.pack) + imps.items.foldLeft(te) { (te, imp) => + val nm = imp.localName + imp.tag.foldLeft(te) { + case (te1, Referant.Value(t)) => + te1.addExternalValue(pack, nm, t) + case (te1, Referant.Constructor(dt, cf)) => + te1.addConstructor(pack, dt, cf) + case (te1, Referant.DefinedT(dt)) => + te1.addDefinedType(dt) } + } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala index c6afb9a29..3bf405e33 100644 --- a/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/SourceConverter.scala @@ -1,7 +1,7 @@ package org.bykn.bosatsu import cats.{Applicative, Traverse} -import cats.data.{ Chain, Ior, NonEmptyChain, NonEmptyList, State } +import cats.data.{Chain, Ior, NonEmptyChain, NonEmptyList, State} import org.bykn.bosatsu.rankn.{ParsedTypeEnv, Type, TypeEnv} import scala.collection.immutable.SortedSet import scala.collection.mutable.{Map => MMap} @@ -20,14 +20,14 @@ import Declaration._ import SourceConverter.{success, Result} -/** - * Convert a source types (a syntactic expression) into - * the internal representations - */ +/** Convert a source types (a syntactic expression) into the internal + * representations + */ final class SourceConverter( - thisPackage: PackageName, - imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], - localDefs: List[TypeDefinitionStatement]) { + thisPackage: PackageName, + imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], + localDefs: List[TypeDefinitionStatement] +) { /* * We should probably error for non-predef name collisions. * Maybe we should even error even or predef collisions that @@ -37,7 +37,8 @@ final class SourceConverter( private val localConstructors = localDefs.flatMap(_.constructors).toSet private val typeCache: MMap[Constructor, Type.Const] = MMap.empty - private val consCache: MMap[Constructor, (PackageName, Constructor)] = MMap.empty + private val consCache: MMap[Constructor, (PackageName, Constructor)] = + MMap.empty private val importedTypes: Map[Identifier, (PackageName, TypeName)] = Referant.importedTypes(imports) @@ -51,7 +52,10 @@ final class SourceConverter( val importedTypeEnv: TypeEnv[Kind.Arg] = Referant.importedTypeEnv(imports)(identity) - private def nameToType(c: Constructor, region: Region): Result[rankn.Type.Const] = + private def nameToType( + c: Constructor, + region: Region + ): Result[rankn.Type.Const] = typeCache.get(c) match { case Some(r) => success(r) case None => @@ -60,8 +64,7 @@ final class SourceConverter( val res = Type.Const.Defined(thisPackage, tc) typeCache.update(c, res) success(res) - } - else { + } else { importedTypes.get(c) match { case Some((p, t)) => val res = Type.Const.Defined(p, t) @@ -72,23 +75,28 @@ final class SourceConverter( val bestEffort = Type.Const.Defined(thisPackage, tc) SourceConverter.partial( SourceConverter.UnknownTypeName(c, region), - bestEffort) + bestEffort + ) } } } private def nameToCons(c: Constructor): (PackageName, Constructor) = - consCache.getOrElseUpdate(c, { - if (localConstructors(c)) (thisPackage, c) - else resolveImportedCons.getOrElse(c, (thisPackage, c)) - }) + consCache.getOrElseUpdate( + c, { + if (localConstructors(c)) (thisPackage, c) + else resolveImportedCons.getOrElse(c, (thisPackage, c)) + } + ) /* * This ignores the name completely and just returns the lambda expression here */ private def toLambdaExpr[B]( - ds: DefStatement[Pattern.Parsed, B], region: Region, tag: Result[Declaration])( - resultExpr: B => Result[Expr[Declaration]]): Result[Expr[Declaration]] = { + ds: DefStatement[Pattern.Parsed, B], + region: Region, + tag: Result[Declaration] + )(resultExpr: B => Result[Expr[Declaration]]): Result[Expr[Declaration]] = { val unTypedBody = resultExpr(ds.result) val bodyType: Option[Result[Type]] = ds.retType.map(toType(_, region)) @@ -102,7 +110,7 @@ final class SourceConverter( type Pat = Pattern[(PackageName, Constructor), Type] val convertedArgs: Result[NonEmptyList[NonEmptyList[Pat]]] = - travNE2.traverse(ds.args)(convertPattern(_, region)) + travNE2.traverse(ds.args)(convertPattern(_, region)) // If we have the full type of the lambda, apply it. This // helps in recursive cases since we can see at the call site @@ -110,49 +118,63 @@ final class SourceConverter( // was incorrect. Without this, type errors become very non-specific. val maybeFullyTyped: Result[Option[Type]] = (convertedArgs, bodyType.sequence).parMapN { case (args, optResTpe) => - (travNE2.traverse(args)((p: Pat) => p.simpleTypeOf), optResTpe).mapN { case (argsTpe, resTpe) => - argsTpe.toList.foldRight(resTpe) { (args, res) => rankn.Type.Fun(args, res) } + (travNE2.traverse(args)((p: Pat) => p.simpleTypeOf), optResTpe).mapN { + case (argsTpe, resTpe) => + argsTpe.toList.foldRight(resTpe) { (args, res) => + rankn.Type.Fun(args, res) + } } } - (convertedArgs, - bodyExp, - tag, - maybeFullyTyped).parMapN { (groups, b, t, fullType) => - val lambda0 = groups.toList.foldRight(b) { case (as, b) => Expr.buildPatternLambda(as, b, t) } - val lambda = fullType.fold(lambda0)(Expr.Annotation(lambda0, _, t)) - ds.typeArgs match { - case None => success(lambda) - case Some(args) => - val bs = args.map { - case (tr, optK) => - (tr.toBoundVar, optK match { - case None => Kind.Type - case Some(k) => k - }) + (convertedArgs, bodyExp, tag, maybeFullyTyped).parMapN { + (groups, b, t, fullType) => + val lambda0 = groups.toList.foldRight(b) { case (as, b) => + Expr.buildPatternLambda(as, b, t) + } + val lambda = fullType.fold(lambda0)(Expr.Annotation(lambda0, _, t)) + ds.typeArgs match { + case None => success(lambda) + case Some(args) => + val bs = args.map { case (tr, optK) => + ( + tr.toBoundVar, + optK match { + case None => Kind.Type + case Some(k) => k + } + ) } - val gen = Expr.forAll(bs.toList, lambda) - val freeVarsList = Expr.freeBoundTyVars(lambda) - val freeVars = freeVarsList.toSet - val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) } - if (notFreeDecl) { - // we have a lint that fails if declTV is not - // a superset of what you would derive from the args - // the purpose here is to control the *order* of - // and to allow introducing phantom parameters, not - // it is confusing if some are explicit, but some are not - SourceConverter.partial( - SourceConverter.InvalidDefTypeParameters(args, freeVarsList, ds, region), - gen) - } - else success(gen) - } - } - .flatten + val gen = Expr.forAll(bs.toList, lambda) + val freeVarsList = Expr.freeBoundTyVars(lambda) + val freeVars = freeVarsList.toSet + val notFreeDecl = bs.exists { case (a, _) => !freeVars(a) } + if (notFreeDecl) { + // we have a lint that fails if declTV is not + // a superset of what you would derive from the args + // the purpose here is to control the *order* of + // and to allow introducing phantom parameters, not + // it is confusing if some are explicit, but some are not + SourceConverter.partial( + SourceConverter.InvalidDefTypeParameters( + args, + freeVarsList, + ds, + region + ), + gen + ) + } else success(gen) + } + }.flatten } - private def resolveToVar[A](ident: Identifier, decl: A, bound: Set[Bindable], topBound: Set[Bindable]): Expr[A] = + private def resolveToVar[A]( + ident: Identifier, + decl: A, + bound: Set[Bindable], + topBound: Set[Bindable] + ): Expr[A] = ident match { - case c@Constructor(_) => + case c @ Constructor(_) => val (p, cons) = nameToCons(c) Expr.Global(p, cons, decl) case b: Bindable => @@ -160,11 +182,10 @@ final class SourceConverter( else if (topBound(b)) { // local top level bindings can shadow imports after they are imported Expr.Global(thisPackage, b, decl) - } - else { + } else { importedNames.get(ident) match { case Some((p, n)) => Expr.Global(p, n, decl) - case None => + case None => // this is an error, but it will be caught later // at type-checking Expr.Local(b, decl) @@ -172,26 +193,36 @@ final class SourceConverter( } } - private def fromDecl(decl: Declaration, bound: Set[Bindable], topBound: Set[Bindable]): Result[Expr[Declaration]] = { + private def fromDecl( + decl: Declaration, + bound: Set[Bindable], + topBound: Set[Bindable] + ): Result[Expr[Declaration]] = { implicit val parAp = SourceConverter.parallelIor def loop(decl: Declaration) = fromDecl(decl, bound, topBound) - def withBound(decl: Declaration, newB: Iterable[Bindable]) = fromDecl(decl, bound ++ newB, topBound) + def withBound(decl: Declaration, newB: Iterable[Bindable]) = + fromDecl(decl, bound ++ newB, topBound) decl match { case Annotation(term, tpe) => - (loop(term), toType(tpe, decl.region)).parMapN(Expr.Annotation(_, _, decl)) + (loop(term), toType(tpe, decl.region)) + .parMapN(Expr.Annotation(_, _, decl)) case Apply(fn, args, _) => (loop(fn), args.toList.traverse(loop(_))) .parMapN { Expr.buildApp(_, _, decl) } - case ao@ApplyOp(left, op, right) => - val opVar: Expr[Declaration] = resolveToVar(op, ao.opVar, bound, topBound) + case ao @ ApplyOp(left, op, right) => + val opVar: Expr[Declaration] = + resolveToVar(op, ao.opVar, bound, topBound) (loop(left), loop(right)).parMapN { (l, r) => Expr.buildApp(opVar, l :: r :: Nil, decl) } case Binding(BindingStatement(pat, value, Padding(_, rest))) => val erest = withBound(rest, pat.names) - def solvePat(pat: Pattern.Parsed, rrhs: Result[Expr[Declaration]]): Result[Expr[Declaration]] = + def solvePat( + pat: Pattern.Parsed, + rrhs: Result[Expr[Declaration]] + ): Result[Expr[Declaration]] = pat match { case Pattern.Var(arg) => (erest, rrhs).parMapN { (e, rhs) => @@ -204,16 +235,17 @@ final class SourceConverter( solvePat(pat, newRhs) } case Pattern.Named(nm, p) => - // this is the same as creating a let nm = value first + // this is the same as creating a let nm = value first (solvePat(p, rrhs), rrhs).parMapN { (inner, rhs) => Expr.Let(nm, rhs, inner, RecursionKind.NonRecursive, decl) } case pat => // TODO: we need the region on the pattern... - (convertPattern(pat, decl.region - value.region), erest, rrhs).parMapN { (newPattern, e, rhs) => - val expBranches = NonEmptyList.of((newPattern, e)) - Expr.Match(rhs, expBranches, decl) - } + (convertPattern(pat, decl.region - value.region), erest, rrhs) + .parMapN { (newPattern, e, rhs) => + val expBranches = NonEmptyList.of((newPattern, e)) + Expr.Match(rhs, expBranches, decl) + } } solvePat(pat, loop(value)) @@ -221,23 +253,30 @@ final class SourceConverter( loop(decl).map(_.as(decl)) case CommentNB(CommentStatement(_, Padding(_, decl))) => loop(decl).map(_.as(decl)) - case DefFn(defstmt@DefStatement(_, _, _, _, _)) => + case DefFn(defstmt @ DefStatement(_, _, _, _, _)) => val inExpr = defstmt.result match { case (_, Padding(_, in)) => withBound(in, defstmt.name :: Nil) } - val newBindings = defstmt.name :: defstmt.args.toList.flatMap(_.patternNames) - val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => withBound(res._1.get, newBindings) }) + val newBindings = + defstmt.name :: defstmt.args.toList.flatMap(_.patternNames) + val lambda = toLambdaExpr(defstmt, decl.region, success(decl))({ res => + withBound(res._1.get, newBindings) + }) (inExpr, lambda).parMapN { (in, lam) => // We rely on DefRecursionCheck to rule out bad recursions val boundName = defstmt.name val rec = - if (UnusedLetCheck.freeBound(lam).contains(boundName)) RecursionKind.Recursive + if (UnusedLetCheck.freeBound(lam).contains(boundName)) + RecursionKind.Recursive else RecursionKind.NonRecursive Expr.Let(boundName, lam, in, recursive = rec, decl) } case IfElse(ifCases, elseCase) => - def loop0(ifs: NonEmptyList[(Expr[Declaration], Expr[Declaration])], elseC: Expr[Declaration]): Expr[Declaration] = + def loop0( + ifs: NonEmptyList[(Expr[Declaration], Expr[Declaration])], + elseC: Expr[Declaration] + ): Expr[Declaration] = ifs match { case NonEmptyList((cond, ifTrue), Nil) => Expr.ifExpr(cond, ifTrue, elseC, decl) @@ -251,15 +290,19 @@ final class SourceConverter( val else1 = loop(elseCase.get) (if1, else1).parMapN(loop0(_, _)) - case tern@Ternary(t, c, f) => - loop(IfElse(NonEmptyList.one((c, OptIndent.same(t))), OptIndent.same(f))(tern.region)) + case tern @ Ternary(t, c, f) => + loop( + IfElse(NonEmptyList.one((c, OptIndent.same(t))), OptIndent.same(f))( + tern.region + ) + ) case Lambda(args, body) => val argsRes = args.traverse(convertPattern(_, decl.region)) val bodyRes = withBound(body, args.patternNames) (argsRes, bodyRes).parMapN { (args, body) => Expr.buildPatternLambda(args, body, decl) } - case la@LeftApply(_, _, _, _) => + case la @ LeftApply(_, _, _, _) => loop(la.rewrite).map(_.as(decl)) case Literal(lit) => success(Expr.Literal(lit, decl)) @@ -278,20 +321,34 @@ final class SourceConverter( newPattern.product(withBound(decl, pat.names)) } (loop(arg), expBranches).parMapN(Expr.Match(_, _, decl)) - case m@Matches(a, p) => + case m @ Matches(a, p) => // x matches p == // match x: // p: True // _: False - val True: Expr[Declaration] = Expr.Global(PackageName.PredefName, Identifier.Constructor("True"), m) - val False: Expr[Declaration] = Expr.Global(PackageName.PredefName, Identifier.Constructor("False"), m) + val True: Expr[Declaration] = + Expr.Global(PackageName.PredefName, Identifier.Constructor("True"), m) + val False: Expr[Declaration] = Expr.Global( + PackageName.PredefName, + Identifier.Constructor("False"), + m + ) (loop(a), convertPattern(p, m.region)).mapN { (a, p) => - val branches = NonEmptyList((p, True), (Pattern.WildCard, False) :: Nil) + val branches = + NonEmptyList((p, True), (Pattern.WildCard, False) :: Nil) Expr.Match(a, branches, m) } - case tc@TupleCons(its) => - val tup0: Expr[Declaration] = Expr.Global(PackageName.PredefName, Identifier.Constructor("Unit"), tc) - val tup2: Expr[Declaration] = Expr.Global(PackageName.PredefName, Identifier.Constructor("TupleCons"), tc) + case tc @ TupleCons(its) => + val tup0: Expr[Declaration] = Expr.Global( + PackageName.PredefName, + Identifier.Constructor("Unit"), + tc + ) + val tup2: Expr[Declaration] = Expr.Global( + PackageName.PredefName, + Identifier.Constructor("TupleCons"), + tc + ) def tup(args: List[Declaration]): Result[Expr[Declaration]] = args match { case Nil => success(tup0) @@ -304,13 +361,13 @@ final class SourceConverter( } tup(its) - case s@StringDecl(parts) => + case s @ StringDecl(parts) => // a single string item should be converted // to that thing, // two or more should be converted this to concat_String([items]) val decls = parts.map { case Right((r, str)) => Literal(Lit(str))(r) - case Left(decl) => decl + case Left(decl) => decl } decls match { @@ -318,17 +375,22 @@ final class SourceConverter( loop(one) case twoOrMore => val lldecl = - ListDecl(ListLang.Cons(twoOrMore.toList.map(SpliceOrItem.Item(_))))(s.region) + ListDecl( + ListLang.Cons(twoOrMore.toList.map(SpliceOrItem.Item(_))) + )(s.region) loop(lldecl).map { listExpr => - val fnName: Expr[Declaration] = - Expr.Global(PackageName.PredefName, Identifier.Name("concat_String"), s) + Expr.Global( + PackageName.PredefName, + Identifier.Name("concat_String"), + s + ) Expr.buildApp(fnName, listExpr.as(s: Declaration) :: Nil, s) } } - case l@ListDecl(list) => + case l @ ListDecl(list) => list match { case ListLang.Cons(items) => val revDecs: Result[List[SpliceOrItem[Expr[Declaration]]]] = @@ -346,10 +408,16 @@ final class SourceConverter( Expr.Global(pn, Identifier.Name(c), l) val Empty: Expr[Declaration] = mkC("EmptyList") - def cons(head: Expr[Declaration], tail: Expr[Declaration]): Expr[Declaration] = + def cons( + head: Expr[Declaration], + tail: Expr[Declaration] + ): Expr[Declaration] = Expr.buildApp(mkC("NonEmptyList"), head :: tail :: Nil, l) - def concat(headList: Expr[Declaration], tail: Expr[Declaration]): Expr[Declaration] = + def concat( + headList: Expr[Declaration], + tail: Expr[Declaration] + ): Expr[Declaration] = Expr.buildApp(mkN("concat"), headList :: tail :: Nil, l) revDecs.map(_.foldLeft(Empty) { @@ -392,10 +460,11 @@ final class SourceConverter( "flat_map_List" } val newBound = binding.names - val opExpr: Expr[Declaration] = Expr.Global(pn, Identifier.Name(opName), l) + val opExpr: Expr[Declaration] = + Expr.Global(pn, Identifier.Name(opName), l) val resExpr: Result[Expr[Declaration]] = filter match { - case None => withBound(res.value, newBound) + case None => withBound(res.value, newBound) case Some(cond) => // To do filters, we lift all results into lists, // so single items must be made singleton lists @@ -407,9 +476,14 @@ final class SourceConverter( // here we lift the result into a a singleton list withBound(r, newBound).map { ritem => Expr.App( - Expr.Global(pn, Identifier.Constructor("NonEmptyList"), rdec), + Expr.Global( + pn, + Identifier.Constructor("NonEmptyList"), + rdec + ), NonEmptyList(ritem, empty :: Nil), - rdec) + rdec + ) } case SpliceOrItem.Splice(r) => withBound(r, newBound) } @@ -418,34 +492,40 @@ final class SourceConverter( Expr.ifExpr(c, sing, empty, cond) } } - (convertPattern(binding, decl.region), - resExpr, - loop(in)).mapN { (newPattern, resExpr, in) => - val fnExpr: Expr[Declaration] = - Expr.buildPatternLambda(NonEmptyList.of(newPattern), resExpr, l) - Expr.buildApp(opExpr, in :: fnExpr :: Nil, l) + (convertPattern(binding, decl.region), resExpr, loop(in)).mapN { + (newPattern, resExpr, in) => + val fnExpr: Expr[Declaration] = + Expr.buildPatternLambda( + NonEmptyList.of(newPattern), + resExpr, + l + ) + Expr.buildApp(opExpr, in :: fnExpr :: Nil, l) } } - case l@DictDecl(dict) => + case l @ DictDecl(dict) => val pn = PackageName.PredefName def mkN(n: String): Expr[Declaration] = Expr.Global(pn, Identifier.Name(n), l) val empty: Expr[Declaration] = Expr.App(mkN("empty_Dict"), NonEmptyList.one(mkN("string_Order")), l) - def add(dict: Expr[Declaration], k: Expr[Declaration], v: Expr[Declaration]): Expr[Declaration] = { + def add( + dict: Expr[Declaration], + k: Expr[Declaration], + v: Expr[Declaration] + ): Expr[Declaration] = { val fn = mkN("add_key") Expr.buildApp(fn, dict :: k :: v :: Nil, l) } dict match { case ListLang.Cons(items) => val revDecs: Result[List[KVPair[Expr[Declaration]]]] = - items.reverse.traverse { - case KVPair(k, v) => - (loop(k), loop(v)).mapN(KVPair(_, _)) + items.reverse.traverse { case KVPair(k, v) => + (loop(k), loop(v)).mapN(KVPair(_, _)) } - revDecs.map(_.foldLeft(empty) { - case (dict, KVPair(k, v)) => add(dict, k, v) + revDecs.map(_.foldLeft(empty) { case (dict, KVPair(k, v)) => + add(dict, k, v) }) case ListLang.Comprehension(KVPair(k, v), binding, in, filter) => /* @@ -463,19 +543,19 @@ final class SourceConverter( val newBound = binding.names val pn = PackageName.PredefName - val opExpr: Expr[Declaration] = Expr.Global(pn, Identifier.Name("foldLeft"), l) + val opExpr: Expr[Declaration] = + Expr.Global(pn, Identifier.Name("foldLeft"), l) val dictSymbol = unusedNames(decl.allNames).next() val init: Expr[Declaration] = Expr.Local(dictSymbol, l) - val added = (withBound(k, newBound), withBound(v, newBound)).mapN(add(init, _, _)) + val added = (withBound(k, newBound), withBound(v, newBound)).mapN( + add(init, _, _) + ) val resExpr: Result[Expr[Declaration]] = filter match { case None => added case Some(cond0) => (added, withBound(cond0, newBound)).mapN { (added, cond) => - Expr.ifExpr(cond, - added, - init, - cond0) + Expr.ifExpr(cond, added, init, cond0) } } val newPattern = convertPattern(binding, decl.region) @@ -484,67 +564,83 @@ final class SourceConverter( Expr.buildPatternLambda( NonEmptyList(Pattern.Var(dictSymbol), pat :: Nil), res, - l) + l + ) Expr.buildApp(opExpr, in :: empty :: foldFn :: Nil, l) } - } - case rc@RecordConstructor(name, args) => - val (p, c) = nameToCons(name) - val cons: Expr[Declaration] = Expr.Global(p, c, rc) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - def argExpr(arg: RecordArg): (Bindable, Result[Expr[Declaration]]) = - arg match { - case RecordArg.Simple(b) => - (b, success(resolveToVar(b, rc, bound, topBound))) - case RecordArg.Pair(k, v) => - (k, loop(v)) - } - val mappingList = args.toList.map(argExpr) - val mapping = mappingList.toMap - - lazy val present = - mappingList - .iterator - .map(_._1) - .foldLeft(SortedSet.empty[Bindable])(_ + _) - - def get(b: Bindable): Result[Expr[Declaration]] = - mapping.get(b) match { - case Some(expr) => expr - case None => - SourceConverter.failure( - SourceConverter.MissingArg(name, rc, present, b, rc.region)) - } - val exprArgs = params.traverse { case (b, _) => get(b) } - - val res = exprArgs.map { args => - Expr.buildApp(cons, args.toList, rc) + } + case rc @ RecordConstructor(name, args) => + val (p, c) = nameToCons(name) + val cons: Expr[Declaration] = Expr.Global(p, c, rc) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + def argExpr(arg: RecordArg): (Bindable, Result[Expr[Declaration]]) = + arg match { + case RecordArg.Simple(b) => + (b, success(resolveToVar(b, rc, bound, topBound))) + case RecordArg.Pair(k, v) => + (k, loop(v)) } - // we also need to check that there are no unused or duplicated - // fields - val paramNamesList = params.map(_._1) - val paramNames = paramNamesList.toSet - // here are all the fields we don't understand - val extra = mappingList.collect { case (k, _) if !paramNames(k) => k } - // Check that the mapping is exactly the right size - NonEmptyList.fromList(extra) match { - case None => res - case Some(extra) => - SourceConverter - .addError(res, - SourceConverter.UnexpectedField(name, rc, extra, paramNamesList, rc.region)) + val mappingList = args.toList.map(argExpr) + val mapping = mappingList.toMap + + lazy val present = + mappingList.iterator + .map(_._1) + .foldLeft(SortedSet.empty[Bindable])(_ + _) + + def get(b: Bindable): Result[Expr[Declaration]] = + mapping.get(b) match { + case Some(expr) => expr + case None => + SourceConverter.failure( + SourceConverter.MissingArg(name, rc, present, b, rc.region) + ) } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(name, rc, decl.region)) - }) + val exprArgs = params.traverse { case (b, _) => get(b) } + + val res = exprArgs.map { args => + Expr.buildApp(cons, args.toList, rc) + } + // we also need to check that there are no unused or duplicated + // fields + val paramNamesList = params.map(_._1) + val paramNames = paramNamesList.toSet + // here are all the fields we don't understand + val extra = mappingList.collect { + case (k, _) if !paramNames(k) => k + } + // Check that the mapping is exactly the right size + NonEmptyList.fromList(extra) match { + case None => res + case Some(extra) => + SourceConverter + .addError( + res, + SourceConverter.UnexpectedField( + name, + rc, + extra, + paramNamesList, + rc.region + ) + ) + } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(name, rc, decl.region) + ) + }) } } private def toType(t: TypeRef, region: Region): Result[Type] = TypeRefConverter[Result](t)(nameToType(_, region)) - def toDefinition(pname: PackageName, tds: TypeDefinitionStatement): Result[rankn.DefinedType[Option[Kind.Arg]]] = { + def toDefinition( + pname: PackageName, + tds: TypeDefinitionStatement + ): Result[rankn.DefinedType[Option[Kind.Arg]]] = { import Statement._ type StT = ((Set[Type.TyVar], List[Type.TyVar]), LazyList[Type.TyVar]) @@ -571,114 +667,144 @@ final class SourceConverter( Type.freeTyVars(pt).map(Type.TyVar(_)) } - def buildParams(args: List[(Bindable, Option[Type])]): VarState[List[(Bindable, Type)]] = + def buildParams( + args: List[(Bindable, Option[Type])] + ): VarState[List[(Bindable, Type)]] = args.traverse(buildParam _) // This is a traverse on List[(Bindable, Option[A])] - val deep = Traverse[List].compose(Traverse[(Bindable, *)]).compose(Traverse[Option]) + val deep = + Traverse[List].compose(Traverse[(Bindable, *)]).compose(Traverse[Option]) def updateInferedWithDecl( - typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], - typeParams0: List[Type.Var.Bound]): Result[List[(Type.Var.Bound, Option[Kind.Arg])]] = - typeArgs match { - case None => success(typeParams0.map((_, None))) - case Some(decl) => - val neBound = decl.map { case (v, k) => (v.toBoundVar, k) } - val declSet = neBound.toList.iterator.map(_._1).toSet - val missingFromDecl = typeParams0.filterNot(declSet) - if ((declSet.size != neBound.size) || missingFromDecl.nonEmpty) { - val bestEffort = neBound.toList.distinctBy(_._1) ::: missingFromDecl.map((_, None)) - // we have a lint that fails if declTV is not - // a superset of what you would derive from the args - // the purpose here is to control the *order* of - // and to allow introducing phantom parameters, not - // it is confusing if some are explicit, but some are not - SourceConverter.partial( - SourceConverter.InvalidTypeParameters(decl, typeParams0, tds), - bestEffort) - } - else success(neBound.toList ::: missingFromDecl.map((_, None))) - } + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], + typeParams0: List[Type.Var.Bound] + ): Result[List[(Type.Var.Bound, Option[Kind.Arg])]] = + typeArgs match { + case None => success(typeParams0.map((_, None))) + case Some(decl) => + val neBound = decl.map { case (v, k) => (v.toBoundVar, k) } + val declSet = neBound.toList.iterator.map(_._1).toSet + val missingFromDecl = typeParams0.filterNot(declSet) + if ((declSet.size != neBound.size) || missingFromDecl.nonEmpty) { + val bestEffort = + neBound.toList.distinctBy(_._1) ::: missingFromDecl.map((_, None)) + // we have a lint that fails if declTV is not + // a superset of what you would derive from the args + // the purpose here is to control the *order* of + // and to allow introducing phantom parameters, not + // it is confusing if some are explicit, but some are not + SourceConverter.partial( + SourceConverter.InvalidTypeParameters(decl, typeParams0, tds), + bestEffort + ) + } else success(neBound.toList ::: missingFromDecl.map((_, None))) + } - def validateArgCount(nm: Constructor, args: Int, region: Region): Result[Unit] = + def validateArgCount( + nm: Constructor, + args: Int, + region: Region + ): Result[Unit] = if (args <= Type.FnType.MaxSize) SourceConverter.successUnit - else SourceConverter.partial( - SourceConverter.TooManyConstructorArgs(nm, args, Type.FnType.MaxSize, region), ()) + else + SourceConverter.partial( + SourceConverter + .TooManyConstructorArgs(nm, args, Type.FnType.MaxSize, region), + () + ) // TODO we have to make sure we don't have more than 8 arguments to a struct // or the constructor Fn won't be a valid function tds match { case Struct(nm, typeArgs, args) => validateArgCount(nm, args.length, tds.region) *> - deep.traverse(args)(toType(_, tds.region)) - .flatMap { argsType => - val declVars = typeArgs.iterator.flatMap(_.toList).map { p => Type.TyVar(p._1.toBoundVar) } - val initVars = existingVars(argsType) - val initState = ((initVars.toSet ++ declVars, initVars.reverse), Type.allBinders.map(Type.TyVar)) - val (((_, typeVars), _), params) = buildParams(argsType).run(initState).value + deep + .traverse(args)(toType(_, tds.region)) + .flatMap { argsType => + val declVars = typeArgs.iterator.flatMap(_.toList).map { p => + Type.TyVar(p._1.toBoundVar) + } + val initVars = existingVars(argsType) + val initState = ( + (initVars.toSet ++ declVars, initVars.reverse), + Type.allBinders.map(Type.TyVar) + ) + val (((_, typeVars), _), params) = + buildParams(argsType).run(initState).value + // we reverse to make sure we see in traversal order + val typeParams0 = reverseMap(typeVars) { tv => + tv.toVar match { + case b @ Type.Var.Bound(_) => b + // $COVERAGE-OFF$ this should be unreachable + case unexpected => + sys.error( + s"unexpectedly parsed a non bound var: $unexpected" + ) + // $COVERAGE-ON$ + } + } + + updateInferedWithDecl(typeArgs, typeParams0).map { typeParams => + val tname = TypeName(nm) + val consFn = rankn.ConstructorFn(nm, params) + + rankn.DefinedType(pname, tname, typeParams, consFn :: Nil) + } + } + case Enum(nm, typeArgs, items) => + items.get + .traverse { case (nm, args) => + validateArgCount(nm, args.length, tds.region) *> + deep + .traverse(args)(toType(_, tds.region)) + .map((nm, _)) + } + .flatMap { conArgs => + val constructorsS = conArgs.traverse { case (nm, argsType) => + buildParams(argsType).map { params => + (nm, params) + } + } + val declVars = typeArgs.iterator.flatMap(_.toList).map { p => + Type.TyVar(p._1.toBoundVar) + } + val initVars = existingVars(conArgs.toList.flatMap(_._2)) + val initState = ( + (initVars.toSet ++ declVars, initVars.reverse), + Type.allBinders.map(Type.TyVar) + ) + val (((_, typeVars), _), constructors) = + constructorsS.run(initState).value // we reverse to make sure we see in traversal order val typeParams0 = reverseMap(typeVars) { tv => tv.toVar match { - case b@Type.Var.Bound(_) => b + case b @ Type.Var.Bound(_) => b // $COVERAGE-OFF$ this should be unreachable case unexpected => sys.error(s"unexpectedly parsed a non bound var: $unexpected") // $COVERAGE-ON$ } } - updateInferedWithDecl(typeArgs, typeParams0).map { typeParams => - val tname = TypeName(nm) - val consFn = rankn.ConstructorFn(nm, params) - - rankn.DefinedType(pname, - tname, - typeParams, - consFn :: Nil) - } - } - case Enum(nm, typeArgs, items) => - items.get.traverse { case (nm, args) => - validateArgCount(nm, args.length, tds.region) *> - deep.traverse(args)(toType(_, tds.region)) - .map((nm, _)) - } - .flatMap { conArgs => - - val constructorsS = conArgs.traverse { case (nm, argsType) => - buildParams(argsType).map { params => - (nm, params) - } - } - val declVars = typeArgs.iterator.flatMap(_.toList).map { p => Type.TyVar(p._1.toBoundVar) } - val initVars = existingVars(conArgs.toList.flatMap(_._2)) - val initState = ((initVars.toSet ++ declVars, initVars.reverse), Type.allBinders.map(Type.TyVar)) - val (((_, typeVars), _), constructors) = constructorsS.run(initState).value - // we reverse to make sure we see in traversal order - val typeParams0 = reverseMap(typeVars) { tv => - tv.toVar match { - case b@Type.Var.Bound(_) => b - // $COVERAGE-OFF$ this should be unreachable - case unexpected => sys.error(s"unexpectedly parsed a non bound var: $unexpected") - // $COVERAGE-ON$ - } - } - updateInferedWithDecl(typeArgs, typeParams0).map { typeParams => - val finalCons = constructors.toList.map { case (c, params) => - rankn.ConstructorFn(c, params) + val finalCons = constructors.toList.map { case (c, params) => + rankn.ConstructorFn(c, params) + } + rankn.DefinedType(pname, TypeName(nm), typeParams, finalCons) } - rankn.DefinedType(pname, TypeName(nm), typeParams, finalCons) } - } case ExternalStruct(nm, targs) => // TODO make a real check here of allowed kinds success( rankn.DefinedType( pname, TypeName(nm), - targs.map { case (TypeRef.TypeVar(v), optK) => (Type.Var.Bound(v), optK) }, - Nil) + targs.map { case (TypeRef.TypeVar(v), optK) => + (Type.Var.Bound(v), optK) + }, + Nil ) + ) } } @@ -694,25 +820,44 @@ final class SourceConverter( loop(as, Nil) } - private def convertPattern(pat: Pattern.Parsed, region: Region): Result[Pattern[(PackageName, Constructor), rankn.Type]] = { + private def convertPattern( + pat: Pattern.Parsed, + region: Region + ): Result[Pattern[(PackageName, Constructor), rankn.Type]] = { val nonTupled = unTuplePattern(pat, region) val collisions = pat.collisionBinds NonEmptyList.fromList(collisions) match { case None => nonTupled case Some(nel) => - SourceConverter.addError(nonTupled, SourceConverter.PatternShadow(nel, pat, region)) + SourceConverter.addError( + nonTupled, + SourceConverter.PatternShadow(nel, pat, region) + ) } } - private[this] val empty = Pattern.PositionalStruct((PackageName.PredefName, Constructor("EmptyList")), Nil) - private[this] val nonEmpty = (PackageName.PredefName, Constructor("NonEmptyList")) - - /** - * As much as possible, convert a list pattern into a normal enum pattern which simplifies - * matching, and possibly allows us to more easily statically remove more of the match - */ - private def unlistPattern(parts: List[Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]]]): Pattern[(PackageName, Constructor), rankn.Type] = { - def loop(parts: List[Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]]], topLevel: Boolean): Pattern[(PackageName, Constructor), rankn.Type] = + private[this] val empty = Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("EmptyList")), + Nil + ) + private[this] val nonEmpty = + (PackageName.PredefName, Constructor("NonEmptyList")) + + /** As much as possible, convert a list pattern into a normal enum pattern + * which simplifies matching, and possibly allows us to more easily + * statically remove more of the match + */ + private def unlistPattern( + parts: List[ + Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]] + ] + ): Pattern[(PackageName, Constructor), rankn.Type] = { + def loop( + parts: List[ + Pattern.ListPart[Pattern[(PackageName, Constructor), rankn.Type]] + ], + topLevel: Boolean + ): Pattern[(PackageName, Constructor), rankn.Type] = parts match { case Nil => empty case Pattern.ListPart.Item(h) :: tail => @@ -724,8 +869,7 @@ final class SourceConverter( // changing to _ would allow more things to typecheck, which we can't do // and we can't annotate because we don't know the type of the list Pattern.ListPat(parts) - } - else { + } else { // we are already in the tail of a list, so we can just put _ here Pattern.WildCard } @@ -735,12 +879,13 @@ final class SourceConverter( // changing to _ would allow more things to typecheck, which we can't do // and we can't annotate because we don't know the type of the list Pattern.ListPat(parts) - } - else { + } else { // we are already in the tail of a list, so we can just put n here Pattern.Var(n) } - case (Pattern.ListPart.WildList :: (i@Pattern.ListPart.Item(Pattern.WildCard)) :: tail) => + case (Pattern.ListPart.WildList :: (i @ Pattern.ListPart.Item( + Pattern.WildCard + )) :: tail) => // [*_, _, x...] = [_, *_, x...] loop(i :: Pattern.ListPart.WildList :: tail, topLevel) case (Pattern.ListPart.WildList | Pattern.ListPart.NamedList(_)) :: _ => @@ -751,152 +896,229 @@ final class SourceConverter( loop(parts, true) } - /** - * Tuples are converted into standard types using an HList strategy - */ - private def unTuplePattern(pat: Pattern.Parsed, region: Region): Result[Pattern[(PackageName, Constructor), rankn.Type]] = - pat.traversePattern[Result, (PackageName, Constructor), rankn.Type]({ - case (Pattern.StructKind.Tuple, args) => - // this is a tuple pattern - def loop[A](args: List[Pattern[(PackageName, Constructor), A]]): Pattern[(PackageName, Constructor), A] = - args match { - case Nil => - // () - Pattern.PositionalStruct( - (PackageName.PredefName, Constructor("Unit")), - Nil) - case h :: tail => - val tailP = loop(tail) - Pattern.PositionalStruct( - (PackageName.PredefName, Constructor("TupleCons")), - h :: tailP :: Nil) - } + /** Tuples are converted into standard types using an HList strategy + */ + private def unTuplePattern( + pat: Pattern.Parsed, + region: Region + ): Result[Pattern[(PackageName, Constructor), rankn.Type]] = + pat.traversePattern[Result, (PackageName, Constructor), rankn.Type]( + { + case (Pattern.StructKind.Tuple, args) => + // this is a tuple pattern + def loop[A]( + args: List[Pattern[(PackageName, Constructor), A]] + ): Pattern[(PackageName, Constructor), A] = + args match { + case Nil => + // () + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("Unit")), + Nil + ) + case h :: tail => + val tailP = loop(tail) + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("TupleCons")), + h :: tailP :: Nil + ) + } - args.map(loop(_)) - case (Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val argLen = args.size - val paramLen = params.size - if (argLen == paramLen) { - SourceConverter.success(Pattern.PositionalStruct(pc, args)) - } - else { - // do the best we can - val fixedArgs = (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)).take(paramLen) - SourceConverter.partial( - SourceConverter.InvalidArgCount(nm, pat, argLen, paramLen, region), - Pattern.PositionalStruct(pc, fixedArgs)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } - case (Pattern.StructKind.NamedPartial(nm, Pattern.StructKind.Style.TupleLike), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val argLen = args.size - val paramLen = params.size - if (argLen <= paramLen) { - val fixedArgs = if (argLen < paramLen) (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)) else args - SourceConverter.success(Pattern.PositionalStruct(pc, fixedArgs)) - } - else { - // we have too many - val fixedArgs = args.take(paramLen) - SourceConverter.partial( - SourceConverter.InvalidArgCount(nm, pat, argLen, paramLen, region), - Pattern.PositionalStruct(pc, fixedArgs)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } - case (Pattern.StructKind.Named(nm, Pattern.StructKind.Style.RecordLike(fs)), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val mapping = fs.toList.iterator.map(_.field).zip(args.iterator).toMap - lazy val present = SortedSet(fs.toList.iterator.map(_.field).toList: _*) - def get(b: Bindable): Result[Pattern[(PackageName, Constructor), rankn.Type]] = - mapping.get(b) match { - case Some(pat) => - SourceConverter.success(pat) - case None => - SourceConverter.partial(SourceConverter.MissingArg(nm, pat, present, b, region), Pattern.WildCard) + args.map(loop(_)) + case ( + Pattern.StructKind.Named(nm, Pattern.StructKind.Style.TupleLike), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val argLen = args.size + val paramLen = params.size + if (argLen == paramLen) { + SourceConverter.success(Pattern.PositionalStruct(pc, args)) + } else { + // do the best we can + val fixedArgs = + (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)) + .take(paramLen) + SourceConverter.partial( + SourceConverter + .InvalidArgCount(nm, pat, argLen, paramLen, region), + Pattern.PositionalStruct(pc, fixedArgs) + ) } - val mapped = - params - .traverse { case (b, _) => get(b) }(SourceConverter.parallelIor) - .map(Pattern.PositionalStruct(pc, _)) - - val paramNamesList = params.map(_._1) - val paramNames = paramNamesList.toSet - // here are all the fields we don't understand - val extra = fs.toList.iterator.map(_.field).filterNot(paramNames).toList - // Check that the mapping is exactly the right size - NonEmptyList.fromList(extra) match { - case None => mapped - case Some(extra) => - SourceConverter - .addError(mapped, - SourceConverter.UnexpectedField(nm, pat, extra, paramNamesList, region)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } - case (Pattern.StructKind.NamedPartial(nm, Pattern.StructKind.Style.RecordLike(fs)), rargs) => - rargs.flatMap { args => - val pc@(p, c) = nameToCons(nm) - localTypeEnv.flatMap(_.getConstructorParams(p, c) match { - case Some(params) => - val mapping = fs.toList.iterator.map(_.field).zip(args.iterator).toMap - def get(b: Bindable): Pattern[(PackageName, Constructor), rankn.Type] = - mapping.get(b) match { - case Some(pat) => pat - case None => Pattern.WildCard + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } + case ( + Pattern.StructKind + .NamedPartial(nm, Pattern.StructKind.Style.TupleLike), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val argLen = args.size + val paramLen = params.size + if (argLen <= paramLen) { + val fixedArgs = + if (argLen < paramLen) + (args ::: List.fill(paramLen - argLen)(Pattern.WildCard)) + else args + SourceConverter.success( + Pattern.PositionalStruct(pc, fixedArgs) + ) + } else { + // we have too many + val fixedArgs = args.take(paramLen) + SourceConverter.partial( + SourceConverter + .InvalidArgCount(nm, pat, argLen, paramLen, region), + Pattern.PositionalStruct(pc, fixedArgs) + ) } - val derefArgs = params.map { case (b, _) => get(b) } - val res0 = SourceConverter.success(Pattern.PositionalStruct(pc, derefArgs)) - - val paramNamesList = params.map(_._1) - val paramNames = paramNamesList.toSet - // here are all the fields we don't understand - val extra = fs.toList.iterator.map(_.field).filterNot(paramNames).toList - // Check that the mapping is exactly the right size - NonEmptyList.fromList(extra) match { - case None => res0 - case Some(extra) => - SourceConverter - .addError(res0, - SourceConverter.UnexpectedField(nm, pat, extra, paramNamesList, region)) - } - case None => - SourceConverter.failure(SourceConverter.UnknownConstructor(nm, pat, region)) - }) - } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } + case ( + Pattern.StructKind + .Named(nm, Pattern.StructKind.Style.RecordLike(fs)), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val mapping = + fs.toList.iterator.map(_.field).zip(args.iterator).toMap + lazy val present = + SortedSet(fs.toList.iterator.map(_.field).toList: _*) + def get( + b: Bindable + ): Result[Pattern[(PackageName, Constructor), rankn.Type]] = + mapping.get(b) match { + case Some(pat) => + SourceConverter.success(pat) + case None => + SourceConverter.partial( + SourceConverter.MissingArg(nm, pat, present, b, region), + Pattern.WildCard + ) + } + val mapped = + params + .traverse { case (b, _) => get(b) }( + SourceConverter.parallelIor + ) + .map(Pattern.PositionalStruct(pc, _)) + + val paramNamesList = params.map(_._1) + val paramNames = paramNamesList.toSet + // here are all the fields we don't understand + val extra = + fs.toList.iterator.map(_.field).filterNot(paramNames).toList + // Check that the mapping is exactly the right size + NonEmptyList.fromList(extra) match { + case None => mapped + case Some(extra) => + SourceConverter + .addError( + mapped, + SourceConverter.UnexpectedField( + nm, + pat, + extra, + paramNamesList, + region + ) + ) + } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } + case ( + Pattern.StructKind + .NamedPartial(nm, Pattern.StructKind.Style.RecordLike(fs)), + rargs + ) => + rargs.flatMap { args => + val pc @ (p, c) = nameToCons(nm) + localTypeEnv.flatMap(_.getConstructorParams(p, c) match { + case Some(params) => + val mapping = + fs.toList.iterator.map(_.field).zip(args.iterator).toMap + def get( + b: Bindable + ): Pattern[(PackageName, Constructor), rankn.Type] = + mapping.get(b) match { + case Some(pat) => pat + case None => Pattern.WildCard + } + val derefArgs = params.map { case (b, _) => get(b) } + val res0 = SourceConverter.success( + Pattern.PositionalStruct(pc, derefArgs) + ) + + val paramNamesList = params.map(_._1) + val paramNames = paramNamesList.toSet + // here are all the fields we don't understand + val extra = + fs.toList.iterator.map(_.field).filterNot(paramNames).toList + // Check that the mapping is exactly the right size + NonEmptyList.fromList(extra) match { + case None => res0 + case Some(extra) => + SourceConverter + .addError( + res0, + SourceConverter.UnexpectedField( + nm, + pat, + extra, + paramNamesList, + region + ) + ) + } + case None => + SourceConverter.failure( + SourceConverter.UnknownConstructor(nm, pat, region) + ) + }) + } }, { t => toType(t, region) }, { items => items.map(unlistPattern) } - )(SourceConverter.parallelIor) // use the parallel, not the default Applicative which is Monadic + )( + SourceConverter.parallelIor + ) // use the parallel, not the default Applicative which is Monadic private lazy val toTypeEnv: Result[ParsedTypeEnv[Option[Kind.Arg]]] = { val sunit = success(()) - val dupTypes = localDefs.groupByNel(_.name) + val dupTypes = localDefs + .groupByNel(_.name) .toList .traverse { case (n, tes) => if (tes.tail.isEmpty) sunit else { val dupRegions = tes.map(_.region) - SourceConverter.partial(SourceConverter.Duplication(n, SourceConverter.DupKind.TypeName, dupRegions), - ()) + SourceConverter.partial( + SourceConverter + .Duplication(n, SourceConverter.DupKind.TypeName, dupRegions), + () + ) } } @@ -910,11 +1132,13 @@ final class SourceConverter( // these are colliding constructors, but if they also collide on type // name we have already reported it above sunit - } - else { + } else { val dupRegions = tes.map(_._2.region) - SourceConverter.partial(SourceConverter.Duplication(n, SourceConverter.DupKind.Constructor, dupRegions), - ()) + SourceConverter.partial( + SourceConverter + .Duplication(n, SourceConverter.DupKind.Constructor, dupRegions), + () + ) } } @@ -931,9 +1155,7 @@ final class SourceConverter( toTypeEnv.map { p => importedTypeEnv ++ TypeEnv.fromParsed(p) } private def anonNameStrings(): Iterator[String] = - rankn.Type - .allBinders - .iterator + rankn.Type.allBinders.iterator .map(_.name) private def unusedNames(allNames: Bindable => Boolean): Iterator[Bindable] = @@ -941,13 +1163,14 @@ final class SourceConverter( .map(Identifier.Name(_)) .filterNot(allNames) - /** - * Externals are not permitted to be shadowed at the top level - */ - private def checkExternalDefShadowing(values: List[Statement.ValueStatement]): Result[Unit] = { + /** Externals are not permitted to be shadowed at the top level + */ + private def checkExternalDefShadowing( + values: List[Statement.ValueStatement] + ): Result[Unit] = { val extDefNames = - values.collect { - case ed@Statement.ExternalDef(name, _, _) => (name, ed.region) + values.collect { case ed @ Statement.ExternalDef(name, _, _) => + (name, ed.region) } val sunit = success(()) @@ -962,15 +1185,22 @@ final class SourceConverter( case NonEmptyList(_, Nil) => sunit case NonEmptyList((_, r1), (_, r2) :: rest) => SourceConverter.partial( - SourceConverter.Duplication(name, SourceConverter.DupKind.ExtDef, NonEmptyList(r1, r2 :: rest.map(_._2))), - ()) + SourceConverter.Duplication( + name, + SourceConverter.DupKind.ExtDef, + NonEmptyList(r1, r2 :: rest.map(_._2)) + ), + () + ) } } - def bindOrDef(s: Statement.ValueStatement): Option[Either[Statement.Bind, Statement.Def]] = + def bindOrDef( + s: Statement.ValueStatement + ): Option[Either[Statement.Bind, Statement.Def]] = s match { - case b@Statement.Bind(_) => Some(Left(b)) - case d@Statement.Def(_) => Some(Right(d)) + case b @ Statement.Bind(_) => Some(Left(b)) + case d @ Statement.Def(_) => Some(Right(d)) case Statement.ExternalDef(_, _, _) => None } @@ -982,16 +1212,15 @@ final class SourceConverter( val shadows = names.filter(extDefNamesSet) NonEmptyList.fromList(shadows) match { - case None => sunit + case None => sunit case Some(nel) => // we are shadowing SourceConverter.partial( - SourceConverter.ExtDefShadow( - SourceConverter.BindKind.Bind, - nel, - s.region), - ()) - } + SourceConverter + .ExtDefShadow(SourceConverter.BindKind.Bind, nel, s.region), + () + ) + } } dupRes *> values.traverse_(checkDefBind) @@ -999,9 +1228,9 @@ final class SourceConverter( } // Flatten pattern bindings out - private def bindingsDecl( - b: Pattern.Parsed, - decl: Declaration)(alloc: () => Bindable): NonEmptyList[(Bindable, Declaration)] = + private def bindingsDecl(b: Pattern.Parsed, decl: Declaration)( + alloc: () => Bindable + ): NonEmptyList[(Bindable, Declaration)] = b match { case Pattern.Var(nm) => NonEmptyList.one((nm, decl)) @@ -1024,8 +1253,7 @@ final class SourceConverter( if (decl.isCheap) { // no need to make a new var to point to a var (Nil, decl) - } - else { + } else { val ident = alloc() val v = Var(ident)(decl.region) ((ident, decl) :: Nil, v) @@ -1038,7 +1266,8 @@ final class SourceConverter( Match( RecursionKind.NonRecursive, rhsNB, - OptIndent.same(NonEmptyList.one((pat, resOI))))(decl.region) + OptIndent.same(NonEmptyList.one((pat, resOI))) + )(decl.region) } val tail: List[(Bindable, Declaration)] = @@ -1061,16 +1290,17 @@ final class SourceConverter( } } - private def parFold[F[_], S, A, B](s0: S, as: List[A])(fn: (S, A) => (S, F[B]))(implicit F: Applicative[F]): F[List[B]] = { + private def parFold[F[_], S, A, B](s0: S, as: List[A])( + fn: (S, A) => (S, F[B]) + )(implicit F: Applicative[F]): F[List[B]] = { val avec = as.toVector def loop(start: Int, end: Int, s: S): (S, F[Chain[B]]) = if (start >= end) (s, F.pure(Chain.empty)) else if (start == (end - 1)) { val (s1, fb) = fn(s, avec(start)) (s1, fb.map(Chain.one(_))) - } - else { - val mid = start + (end - start)/2 + } else { + val mid = start + (end - start) / 2 val (s1, f1) = loop(start, mid, s) val (s2, f2) = loop(mid, end, s1) (s2, F.map2(f1, f2)(_ ++ _)) @@ -1079,17 +1309,16 @@ final class SourceConverter( loop(0, avec.size, s0)._2.map(_.toList) } - /** - * Return the lets in order they appear - */ - private def toLets(stmts: Seq[Statement.ValueStatement]): Result[List[(Bindable, RecursionKind, Expr[Declaration])]] = { + /** Return the lets in order they appear + */ + private def toLets( + stmts: Seq[Statement.ValueStatement] + ): Result[List[(Bindable, RecursionKind, Expr[Declaration])]] = { import Statement._ val newName: () => Bindable = { lazy val allNames: Set[Bindable] = - stmts - .flatMap { v => v.names.iterator ++ v.allNames.iterator } - .toSet + stmts.flatMap { v => v.names.iterator ++ v.allNames.iterator }.toSet // Each time we need a name, we can call anonNames.next() // it is mutable, but in a limited scope @@ -1102,15 +1331,14 @@ final class SourceConverter( val flatList: List[(Bindable, RecursionKind, Flattened)] = stmts.toList.flatMap { - case d@Def(_) => + case d @ Def(_) => (d.defstatement.name, RecursionKind.Recursive, Left(d)) :: Nil case ExternalDef(_, _, _) => // we don't allow external defs to shadow at all, so skip it here Nil case Bind(BindingStatement(bound, decl, _)) => - bindingsDecl(bound, decl)(newName) - .toList - .map { case pair@(b, _) => + bindingsDecl(bound, decl)(newName).toList + .map { case pair @ (b, _) => (b, RecursionKind.NonRecursive, Right(pair)) } } @@ -1121,45 +1349,52 @@ final class SourceConverter( // TODO make a better name, close to the original, but also not colliding // by using idx val newNameV: Bindable = newName() - val fn: Flattened => Flattened = - { - case Left(d@Def(dstmt)) => - val d1 = if (dstmt.name === bind) dstmt.copy(name = newNameV) else dstmt - val res = - if (dstmt.args.flatten.iterator.flatMap(_.names).exists(_ == bind)) { - // the args are shadowing the binding, so we don't need to substitute - dstmt.result - } - else { - dstmt.result.map { body => - Declaration.substitute(bind, Var(newNameV)(body.region), body) match { - case Some(body1) => body1 - case None => - // $COVERAGE-OFF$ - throw new IllegalStateException("we know newName can't mask") - // $COVERAGE-ON$ - } + val fn: Flattened => Flattened = { + case Left(d @ Def(dstmt)) => + val d1 = + if (dstmt.name === bind) dstmt.copy(name = newNameV) else dstmt + val res = + if ( + dstmt.args.flatten.iterator.flatMap(_.names).exists(_ == bind) + ) { + // the args are shadowing the binding, so we don't need to substitute + dstmt.result + } else { + dstmt.result.map { body => + Declaration.substitute( + bind, + Var(newNameV)(body.region), + body + ) match { + case Some(body1) => body1 + case None => + // $COVERAGE-OFF$ + throw new IllegalStateException( + "we know newName can't mask" + ) + // $COVERAGE-ON$ } } - Left(Def(d1.copy(result = res))(d.region)) - case Right((b0, d)) => - // we don't need to update b0, we discard it anyway - Declaration.substitute(bind, Var(newNameV)(d.region), d) match { - case Some(d1) => Right((b0, d1)) - // $COVERAGE-OFF$ - case None => - throw new IllegalStateException("we know newName can't mask") - // $COVERAGE-ON$ } - } + Left(Def(d1.copy(result = res))(d.region)) + case Right((b0, d)) => + // we don't need to update b0, we discard it anyway + Declaration.substitute(bind, Var(newNameV)(d.region), d) match { + case Some(d1) => Right((b0, d1)) + // $COVERAGE-OFF$ + case None => + throw new IllegalStateException("we know newName can't mask") + // $COVERAGE-ON$ + } + } (newNameV, fn) } val withEx: List[Either[ExternalDef, Flattened]] = - stmts.collect { case e@ExternalDef(_, _, _) => Left(e) }.toList ::: + stmts.collect { case e @ ExternalDef(_, _, _) => Left(e) }.toList ::: flatIn.map { - case (b, _, Left(d@Def(dstmt))) => + case (b, _, Left(d @ Def(dstmt))) => Right(Left(Def(dstmt.copy(name = b))(d.region))) case (b, _, Right((_, d))) => Right(Right((b, d))) } @@ -1167,15 +1402,20 @@ final class SourceConverter( parFold(Set.empty[Bindable], withEx) { case (topBound, stmt) => stmt match { case Right(Right((nm, decl))) => - - val r = fromDecl(decl, Set.empty, topBound).map((nm, RecursionKind.NonRecursive, _) :: Nil) + val r = fromDecl(decl, Set.empty, topBound).map( + (nm, RecursionKind.NonRecursive, _) :: Nil + ) // make sure all the free types are Generic // we have to do this at the top level because in Declaration => Expr // we allow closing over type variables defined at a higher level - val r1 = r.map { exs => exs.map { case (n, r, e) => (n, r, Expr.quantifyFrees(e)) } } + val r1 = r.map { exs => + exs.map { case (n, r, e) => (n, r, Expr.quantifyFrees(e)) } + } (topBound + nm, r1) - case Right(Left(d @ Def(defstmt@DefStatement(_, _, argGroups, _, _)))) => + case Right( + Left(d @ Def(defstmt @ DefStatement(_, _, argGroups, _, _))) + ) => // using body for the outer here is a bummer, but not really a good outer otherwise val boundName = defstmt.name @@ -1186,15 +1426,20 @@ final class SourceConverter( toLambdaExpr[OptIndent[Declaration]]( defstmt, d.region, - success(defstmt.result.get))( - { (res: OptIndent[Declaration]) => - fromDecl(res.get, argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, topBound1) - }) + success(defstmt.result.get) + )({ (res: OptIndent[Declaration]) => + fromDecl( + res.get, + argGroups.flatten.iterator.flatMap(_.names).toSet + boundName, + topBound1 + ) + }) val r = lam.map { (l: Expr[Declaration]) => // We rely on DefRecursionCheck to rule out bad recursions val rec = - if (UnusedLetCheck.freeBound(l).contains(boundName)) RecursionKind.Recursive + if (UnusedLetCheck.freeBound(l).contains(boundName)) + RecursionKind.Recursive else RecursionKind.NonRecursive // make sure all the free types are Generic // we have to do this at the top level because in Declaration => Expr @@ -1207,14 +1452,21 @@ final class SourceConverter( (topBound + n, success(Nil)) } }(SourceConverter.parallelIor) - .map(_.flatten) + .map(_.flatten) } - def toProgram(ss: List[Statement]): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[Declaration], List[Statement]]] = { + def toProgram( + ss: List[Statement] + ): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[ + Declaration + ], List[Statement]]] = { val stmts = Statement.valuesOf(ss).toList - stmts.collect { - case ed@Statement.ExternalDef(name, params, result) => - (params.traverse { p => toType(p._2, ed.region) }, toType(result, ed.region)) + stmts + .collect { case ed @ Statement.ExternalDef(name, params, result) => + ( + params.traverse { p => toType(p._2, ed.region) }, + toType(result, ed.region) + ) .flatMapN { (paramTypes, resType) => NonEmptyList.fromList(paramTypes) match { case None => success(resType) @@ -1224,7 +1476,10 @@ final class SourceConverter( case None => val invalid = rankn.Type.Fun(nel, resType) SourceConverter - .partial(SourceConverter.InvalidArity(nel.length, ed.region), invalid) + .partial( + SourceConverter.InvalidArity(nel.length, ed.region), + invalid + ) } } } @@ -1232,32 +1487,34 @@ final class SourceConverter( val freeVars = rankn.Type.freeTyVars(tpe :: Nil) // these vars were parsed so they are never skolem vars val freeBound = freeVars.map { - case b@rankn.Type.Var.Bound(_) => b - case s@rankn.Type.Var.Skolem(_, _, _) => + case b @ rankn.Type.Var.Bound(_) => b + case s @ rankn.Type.Var.Skolem(_, _, _) => // $COVERAGE-OFF$ this should be unreachable sys.error(s"invariant violation: parsed a skolem var: $s") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } // TODO: Kind support parsing kinds - val maybeForAll = rankn.Type.forAll(freeBound.map { n => (n, Kind.Type) }, tpe) + val maybeForAll = + rankn.Type.forAll(freeBound.map { n => (n, Kind.Type) }, tpe) (name, maybeForAll) } - } - // TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]] - // where we drop all total failures in order to make more progress - .sequence - .flatMap { exts => - val pte1 = toTypeEnv.map { p => - exts.foldLeft(p) { case (pte, (name, tpe)) => - pte.addExternalValue(thisPackage, name, tpe) - } } + // TODO: we could implement Iterable[Ior[A, B]] => Ior[A, Iterble[B]] + // where we drop all total failures in order to make more progress + .sequence + .flatMap { exts => + val pte1 = toTypeEnv.map { p => + exts.foldLeft(p) { case (pte, (name, tpe)) => + pte.addExternalValue(thisPackage, name, tpe) + } + } - implicit val parallel = SourceConverter.parallelIor - (checkExternalDefShadowing(stmts), toLets(stmts), pte1).mapN { (_, binds, pte1) => - Program((importedTypeEnv, pte1), binds, exts.map(_._1).toList, ss) + implicit val parallel = SourceConverter.parallelIor + (checkExternalDefShadowing(stmts), toLets(stmts), pte1).mapN { + (_, binds, pte1) => + Program((importedTypeEnv, pte1), binds, exts.map(_._1).toList, ss) + } } - } } } @@ -1267,7 +1524,8 @@ object SourceConverter { def success[A](a: A): Result[A] = Ior.Right(a) val successUnit: Result[Unit] = success(()) - def partial[A](err: Error, a: A): Result[A] = Ior.Both(NonEmptyChain.one(err), a) + def partial[A](err: Error, a: A): Result[A] = + Ior.Both(NonEmptyChain.one(err), a) def failure[A](err: Error): Result[A] = Ior.Left(NonEmptyChain.one(err)) def addError[A](r: Result[A], err: Error): Result[A] = @@ -1277,114 +1535,124 @@ object SourceConverter { private val parallelIor: Applicative[Result] = Ior.catsDataParallelForIor[NonEmptyChain[Error]].applicative - def toProgram( - thisPackage: PackageName, - imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], - stmts: List[Statement]): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[Declaration], List[Statement]]] = - (new SourceConverter(thisPackage, imports, Statement.definitionsOf(stmts).toList)).toProgram(stmts) + def toProgram( + thisPackage: PackageName, + imports: List[Import[PackageName, NonEmptyList[Referant[Kind.Arg]]]], + stmts: List[Statement] + ): Result[Program[(TypeEnv[Kind.Arg], ParsedTypeEnv[Option[Kind.Arg]]), Expr[ + Declaration + ], List[Statement]]] = + (new SourceConverter( + thisPackage, + imports, + Statement.definitionsOf(stmts).toList + )).toProgram(stmts) private def concat[A](ls: List[A], tail: NonEmptyList[A]): NonEmptyList[A] = ls match { - case Nil => tail + case Nil => tail case h :: t => NonEmptyList(h, t ::: tail.toList) } - /** - * For all duplicate binds, for all but the final - * value, rename them - */ - def makeLetsUnique[D]( - lets: List[(Bindable, RecursionKind, D)])( - newName: (Bindable, Int) => (Bindable, D => D)): List[(Bindable, RecursionKind, D)] = - NonEmptyList.fromList(lets) match { - case None => Nil - case Some(nelets) => - // there is at least 1 let, but maybe no duplicates - val dups: Map[Bindable, Int] = - nelets.foldLeft(Map.empty[Bindable, Int]) { - case (bound, (b, _, _)) => - bound.get(b) match { - case Some(c) => bound.updated(b, c + 1) - case None => bound.updated(b, 1) - } + /** For all duplicate binds, for all but the final value, rename them + */ + def makeLetsUnique[D](lets: List[(Bindable, RecursionKind, D)])( + newName: (Bindable, Int) => (Bindable, D => D) + ): List[(Bindable, RecursionKind, D)] = + NonEmptyList.fromList(lets) match { + case None => Nil + case Some(nelets) => + // there is at least 1 let, but maybe no duplicates + val dups: Map[Bindable, Int] = + nelets + .foldLeft(Map.empty[Bindable, Int]) { case (bound, (b, _, _)) => + bound.get(b) match { + case Some(c) => bound.updated(b, c + 1) + case None => bound.updated(b, 1) + } } .filter { case (_, v) => v > 1 } - if (dups.isEmpty) { - // no duplicated top level names - lets - } - else { - // we rename all but the last name for each duplicate - type BRD = (Bindable, RecursionKind, D) - - /* - * Invariant, lets.exists(_._1 == name) == true - * if this is false, this method will throw - */ - @annotation.tailrec - def renameUntilNext(name: Bindable, lets: NonEmptyList[BRD], acc: List[BRD])(fn: D => D): NonEmptyList[BRD] = { - // note this is a total match: - val NonEmptyList(head @ (b, r, d), tail) = lets - - if (b == name) { - val head1 = - if (r.isRecursive) { - // the new b is in scope right away - head - } - else { - // the old b1 is in scope for this one - (b, r, fn(d)) - } - NonEmptyList(head1, acc).reverse.concat(tail) - } - else { - // if b != name, then that implies there is - // at least one item in the tail with b, - // so tail cannot be empty - val netail = NonEmptyList.fromListUnsafe(tail) - renameUntilNext(name, netail, (b, r, fn(d)) :: acc)(fn) - } + if (dups.isEmpty) { + // no duplicated top level names + lets + } else { + // we rename all but the last name for each duplicate + type BRD = (Bindable, RecursionKind, D) + + /* + * Invariant, lets.exists(_._1 == name) == true + * if this is false, this method will throw + */ + @annotation.tailrec + def renameUntilNext( + name: Bindable, + lets: NonEmptyList[BRD], + acc: List[BRD] + )(fn: D => D): NonEmptyList[BRD] = { + // note this is a total match: + val NonEmptyList(head @ (b, r, d), tail) = lets + + if (b == name) { + val head1 = + if (r.isRecursive) { + // the new b is in scope right away + head + } else { + // the old b1 is in scope for this one + (b, r, fn(d)) + } + NonEmptyList(head1, acc).reverse.concat(tail) + } else { + // if b != name, then that implies there is + // at least one item in the tail with b, + // so tail cannot be empty + val netail = NonEmptyList.fromListUnsafe(tail) + renameUntilNext(name, netail, (b, r, fn(d)) :: acc)(fn) } + } - @annotation.tailrec - def loop(lets: NonEmptyList[BRD], state: Map[Bindable, (Int, Int)], acc: List[BRD]): NonEmptyList[BRD] = { - val head = lets.head - NonEmptyList.fromList(lets.tail) match { - case Some(netail) => - val (b, r, d) = head - state.get(b) match { - case Some((cnt, sz)) if cnt < (sz - 1) => - val newState = state.updated(b, (cnt + 1, sz)) - // we have to rename until the next bind - val (b1, renamer) = newName(b, cnt) - val d1 = - if (r.isRecursive) renamer(d) - else d - - val head1 = (b1, r, d1) - // since cnt < (sz - 1) we know that - // b must occur at least once in netail - val tail1 = renameUntilNext(b, netail, Nil)(renamer) - loop(tail1, newState, head1 :: acc) - case _ => - // this is the last one or not a duplicate, we don't change it - loop(netail, state, head :: acc) - } - case None => - // the last one is never renamed - NonEmptyList(head, acc).reverse - } + @annotation.tailrec + def loop( + lets: NonEmptyList[BRD], + state: Map[Bindable, (Int, Int)], + acc: List[BRD] + ): NonEmptyList[BRD] = { + val head = lets.head + NonEmptyList.fromList(lets.tail) match { + case Some(netail) => + val (b, r, d) = head + state.get(b) match { + case Some((cnt, sz)) if cnt < (sz - 1) => + val newState = state.updated(b, (cnt + 1, sz)) + // we have to rename until the next bind + val (b1, renamer) = newName(b, cnt) + val d1 = + if (r.isRecursive) renamer(d) + else d + + val head1 = (b1, r, d1) + // since cnt < (sz - 1) we know that + // b must occur at least once in netail + val tail1 = renameUntilNext(b, netail, Nil)(renamer) + loop(tail1, newState, head1 :: acc) + case _ => + // this is the last one or not a duplicate, we don't change it + loop(netail, state, head :: acc) + } + case None => + // the last one is never renamed + NonEmptyList(head, acc).reverse } + } - // there are duplicates - val dupState: Map[Bindable, (Int, Int)] = - dups.iterator.map { case (k, sz) => (k, (0, sz)) }.toMap + // there are duplicates + val dupState: Map[Bindable, (Int, Int)] = + dups.iterator.map { case (k, sz) => (k, (0, sz)) }.toMap - loop(nelets, dupState, Nil).toList - } + loop(nelets, dupState, Nil).toList } + } sealed abstract class Error { def region: Region @@ -1402,7 +1670,11 @@ object SourceConverter { final case object Bind extends BindKind("bind") } - final case class ExtDefShadow(kind: BindKind, names: NonEmptyList[Bindable], region: Region) extends Error { + final case class ExtDefShadow( + kind: BindKind, + names: NonEmptyList[Bindable], + region: Region + ) extends Error { def message = { val ns = names.toList.iterator.map(_.sourceCodeRepr).mkString(", ") s"${kind.asString} names $ns shadow external def" @@ -1416,13 +1688,21 @@ object SourceConverter { case object Constructor extends DupKind("constructor") } - final case class Duplication(name: Identifier, kind: DupKind, duplicates: NonEmptyList[Region]) extends Error { + final case class Duplication( + name: Identifier, + kind: DupKind, + duplicates: NonEmptyList[Region] + ) extends Error { def region = duplicates.head def message = s"${kind.asString}: ${name.sourceCodeRepr} defined multiple times" } - final case class PatternShadow(names: NonEmptyList[Bindable], pattern: Pattern.Parsed, region: Region) extends Error { + final case class PatternShadow( + names: NonEmptyList[Bindable], + pattern: Pattern.Parsed, + region: Region + ) extends Error { def message = { val str = names.toList.map(_.sourceCodeRepr).mkString(", ") "repeated bindings in pattern: " + str @@ -1436,7 +1716,8 @@ object SourceConverter { final case class Pat(toPattern: Pattern.Parsed) extends ConstructorSyntax { def toDoc = Document[Pattern.Parsed].document(toPattern) } - final case class RecCons(toDeclaration: Declaration.RecordConstructor) extends ConstructorSyntax { + final case class RecCons(toDeclaration: Declaration.RecordConstructor) + extends ConstructorSyntax { def toDoc = toDeclaration.toDoc } @@ -1447,10 +1728,19 @@ object SourceConverter { RecCons(c) } - final case class UnknownConstructor(name: Constructor, syntax: ConstructorSyntax, region: Region) extends ConstructorError { + final case class UnknownConstructor( + name: Constructor, + syntax: ConstructorSyntax, + region: Region + ) extends ConstructorError { def message = { val maybeDoc = syntax match { - case ConstructorSyntax.Pat(Pattern.PositionalStruct(Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike), Nil)) if n == name => + case ConstructorSyntax.Pat( + Pattern.PositionalStruct( + Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike), + Nil + ) + ) if n == name => // the pattern is just name Doc.empty case _ => @@ -1459,28 +1749,59 @@ object SourceConverter { (Doc.text(s"unknown constructor ${name.asString}") + maybeDoc).render(80) } } - final case class InvalidArgCount(name: Constructor, syntax: ConstructorSyntax, argCount: Int, expected: Int, region: Region) extends ConstructorError { + final case class InvalidArgCount( + name: Constructor, + syntax: ConstructorSyntax, + argCount: Int, + expected: Int, + region: Region + ) extends ConstructorError { def message = - (Doc.text(s"invalid argument count in ${name.asString}, found $argCount expected $expected") + Doc.lineOrSpace + syntax.toDoc).render(80) + (Doc.text( + s"invalid argument count in ${name.asString}, found $argCount expected $expected" + ) + Doc.lineOrSpace + syntax.toDoc).render(80) } - final case class MissingArg(name: Constructor, syntax: ConstructorSyntax, present: SortedSet[Bindable], missing: Bindable, region: Region) extends ConstructorError { + final case class MissingArg( + name: Constructor, + syntax: ConstructorSyntax, + present: SortedSet[Bindable], + missing: Bindable, + region: Region + ) extends ConstructorError { def message = - (Doc.text(s"missing field ${missing.asString} in ${name.asString}") + Doc.lineOrSpace + syntax.toDoc).render(80) + (Doc.text( + s"missing field ${missing.asString} in ${name.asString}" + ) + Doc.lineOrSpace + syntax.toDoc).render(80) } - final case class UnexpectedField(name: Constructor, syntax: ConstructorSyntax, unexpected: NonEmptyList[Bindable], expected: List[Bindable], region: Region) extends ConstructorError { + final case class UnexpectedField( + name: Constructor, + syntax: ConstructorSyntax, + unexpected: NonEmptyList[Bindable], + expected: List[Bindable], + region: Region + ) extends ConstructorError { def message = { val plural = if (unexpected.tail.isEmpty) "field" else "fields" - val unexDoc = Doc.intercalate(Doc.comma + Doc.lineOrSpace, unexpected.toList.map { b => Doc.text(b.asString) }) - val exDoc = Doc.intercalate(Doc.comma + Doc.lineOrSpace, expected.map { b => Doc.text(b.asString) }) + val unexDoc = Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + unexpected.toList.map { b => Doc.text(b.asString) } + ) + val exDoc = Doc.intercalate( + Doc.comma + Doc.lineOrSpace, + expected.map { b => Doc.text(b.asString) } + ) (Doc.text(s"unexpected $plural: ") + unexDoc + Doc.lineOrSpace + - Doc.text(s"in ${name.asString}, expected: ") + exDoc + Doc.lineOrSpace + syntax.toDoc).render(80) - } + Doc.text( + s"in ${name.asString}, expected: " + ) + exDoc + Doc.lineOrSpace + syntax.toDoc).render(80) + } } final case class InvalidTypeParameters( - declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])], - discoveredTypes: List[Type.Var.Bound], - statement: TypeDefinitionStatement) extends Error { + declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])], + discoveredTypes: List[Type.Var.Bound], + statement: TypeDefinitionStatement + ) extends Error { def region = statement.region def message = { @@ -1488,45 +1809,61 @@ object SourceConverter { l.iterator.map(_.name).mkString("[", ", ", "]") val decl = - TypeRef.docTypeArgs(declaredParams.toList) { - case None => Doc.empty - case Some(ka) => Doc.text(": ") + Kind.argDoc(ka) - }.renderTrim(80) + TypeRef + .docTypeArgs(declaredParams.toList) { + case None => Doc.empty + case Some(ka) => Doc.text(": ") + Kind.argDoc(ka) + } + .renderTrim(80) val disc = tstr(discoveredTypes) s"${statement.name.asString} found declared: $decl, not a superset of $disc" } } final case class InvalidDefTypeParameters[B]( - declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])], - free: List[Type.Var.Bound], - defstmt: DefStatement[Pattern.Parsed, B], - region: Region) extends Error { + declaredParams: NonEmptyList[(TypeRef.TypeVar, Option[Kind])], + free: List[Type.Var.Bound], + defstmt: DefStatement[Pattern.Parsed, B], + region: Region + ) extends Error { def message = { def tstr(l: List[Type.Var.Bound]): String = l.iterator.map(_.name).mkString("[", ", ", "]") - val decl = TypeRef.docTypeArgs(declaredParams.toList) { - case None => Doc.empty - case Some(k) => Doc.text(": ") + Kind.toDoc(k) - }.renderTrim(80) + val decl = TypeRef + .docTypeArgs(declaredParams.toList) { + case None => Doc.empty + case Some(k) => Doc.text(": ") + Kind.toDoc(k) + } + .renderTrim(80) val freeStr = tstr(free) s"${defstmt.name.asString} found declared types: $decl, not a subset of $freeStr" } } - final case class UnknownTypeName(tpe: Constructor, region: Region) extends Error { + final case class UnknownTypeName(tpe: Constructor, region: Region) + extends Error { def message = s"unknown type: ${tpe.asString}" } final case class InvalidArity(size: Int, region: Region) extends Error { - def message = s"invalid function arguments = $size, maximum = ${rankn.Type.FnType.MaxSize}" + def message = + s"invalid function arguments = $size, maximum = ${rankn.Type.FnType.MaxSize}" } - final case class TooManyConstructorArgs(name: Constructor, argCount: Int, max: Int, region: Region) extends Error { + final case class TooManyConstructorArgs( + name: Constructor, + argCount: Int, + max: Int, + region: Region + ) extends Error { def message = - Doc.text(s"invalid argument count in constructor for ${name.asString} found $argCount maximum allowed $max").render(80) + Doc + .text( + s"invalid argument count in constructor for ${name.asString} found $argCount maximum allowed $max" + ) + .render(80) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/Statement.scala b/core/src/main/scala/org/bykn/bosatsu/Statement.scala index 4469a94b9..16fb21833 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Statement.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Statement.scala @@ -1,10 +1,10 @@ package org.bykn.bosatsu -import Parser.{ Combinators, Indy, maybeSpace, keySpace, toEOL } +import Parser.{Combinators, Indy, maybeSpace, keySpace, toEOL} import cats.data.NonEmptyList import cats.implicits._ import cats.parse.{Parser0 => P0, Parser => P} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import scala.collection.immutable.SortedSet import Indy.IndyMethods @@ -12,10 +12,9 @@ import Identifier.{Bindable, Constructor} sealed abstract class Statement { - /** - * This describes the region of the current statement, not the entire linked list - * of statements - */ + /** This describes the region of the current statement, not the entire linked + * list of statements + */ def region: Region def replaceRegions(r: Region): Statement = { @@ -46,14 +45,12 @@ sealed abstract class Statement { sealed abstract class TypeDefinitionStatement extends Statement { import Statement.{Struct, Enum, ExternalStruct} - /** - * This is the name of the type being defined - */ + /** This is the name of the type being defined + */ def name: Constructor - /** - * here are the names of the constructors for this type - */ + /** here are the names of the constructors for this type + */ def constructors: List[Constructor] = this match { case Struct(nm, _, _) => nm :: Nil @@ -65,48 +62,53 @@ sealed abstract class TypeDefinitionStatement extends Statement { object Statement { - def definitionsOf(stmts: Iterable[Statement]): LazyList[TypeDefinitionStatement] = - stmts.iterator.collect { case tds: TypeDefinitionStatement => tds }.to(LazyList) + def definitionsOf( + stmts: Iterable[Statement] + ): LazyList[TypeDefinitionStatement] = + stmts.iterator + .collect { case tds: TypeDefinitionStatement => tds } + .to(LazyList) def valuesOf(stmts: Iterable[Statement]): LazyList[ValueStatement] = stmts.iterator.collect { case vs: ValueStatement => vs }.to(LazyList) - /** - * These introduce new values into scope - */ + /** These introduce new values into scope + */ sealed abstract class ValueStatement extends Statement { - /** - * All the names that are bound by this statement - */ + + /** All the names that are bound by this statement + */ def names: List[Bindable] = this match { - case Bind(BindingStatement(bound, _, _)) => bound.names // TODO Keep identifiers - case Def(defstatement) => defstatement.name :: Nil + case Bind(BindingStatement(bound, _, _)) => + bound.names // TODO Keep identifiers + case Def(defstatement) => defstatement.name :: Nil case ExternalDef(name, _, _) => name :: Nil } - /** - * These are all the free bindable names in the right hand side - * of this binding - */ + /** These are all the free bindable names in the right hand side of this + * binding + */ def freeVars: SortedSet[Bindable] = this match { case Bind(BindingStatement(_, decl, _)) => decl.freeVars case Def(defstatement) => val innerFrees = defstatement.result.get.freeVars // but the def name and, args shadow - (innerFrees - defstatement.name) -- defstatement.args.toList.flatMap(_.patternNames) + (innerFrees - defstatement.name) -- defstatement.args.toList.flatMap( + _.patternNames + ) case ExternalDef(_, _, _) => SortedSet.empty } - /** - * These are all the bindings, free or not, in this Statement - */ + /** These are all the bindings, free or not, in this Statement + */ def allNames: SortedSet[Bindable] = { this match { case Bind(BindingStatement(pat, decl, _)) => decl.allNames ++ pat.names case Def(defstatement) => - (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList.flatMap(_.patternNames) + (defstatement.result.get.allNames + defstatement.name) ++ defstatement.args.toList + .flatMap(_.patternNames) case ExternalDef(name, _, _) => SortedSet(name) } } @@ -114,180 +116,236 @@ object Statement { ////// // All the ValueStatements, which set up new bindings in the order they appear in the file - /////. - case class Bind(bind: BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit])(val region: Region) extends ValueStatement - case class Def(defstatement: DefStatement[Pattern.Parsed, OptIndent[Declaration]])(val region: Region) extends ValueStatement - case class ExternalDef(name: Bindable, params: List[(Bindable, TypeRef)], result: TypeRef)(val region: Region) extends ValueStatement + ///// . + case class Bind( + bind: BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit] + )(val region: Region) + extends ValueStatement + case class Def( + defstatement: DefStatement[Pattern.Parsed, OptIndent[Declaration]] + )(val region: Region) + extends ValueStatement + case class ExternalDef( + name: Bindable, + params: List[(Bindable, TypeRef)], + result: TypeRef + )(val region: Region) + extends ValueStatement ////// // TypeDefinitionStatement types: ////// - case class Enum(name: Constructor, - typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], - items: OptIndent[NonEmptyList[(Constructor, List[(Bindable, Option[TypeRef])])]] - )(val region: Region) extends TypeDefinitionStatement - case class ExternalStruct(name: Constructor, typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])])(val region: Region) extends TypeDefinitionStatement - case class Struct(name: Constructor, - typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], - args: List[(Bindable, Option[TypeRef])])(val region: Region) extends TypeDefinitionStatement + case class Enum( + name: Constructor, + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], + items: OptIndent[ + NonEmptyList[(Constructor, List[(Bindable, Option[TypeRef])])] + ] + )(val region: Region) + extends TypeDefinitionStatement + case class ExternalStruct( + name: Constructor, + typeArgs: List[(TypeRef.TypeVar, Option[Kind.Arg])] + )(val region: Region) + extends TypeDefinitionStatement + case class Struct( + name: Constructor, + typeArgs: Option[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]], + args: List[(Bindable, Option[TypeRef])] + )(val region: Region) + extends TypeDefinitionStatement //// // These have no effect on the semantics of the Statement linked list //// - case class PaddingStatement(padding: Padding[Unit])(val region: Region) extends Statement - case class Comment(comment: CommentStatement[Unit])(val region: Region) extends Statement + case class PaddingStatement(padding: Padding[Unit])(val region: Region) + extends Statement + case class Comment(comment: CommentStatement[Unit])(val region: Region) + extends Statement // Parse a single item final val parser1: P[Statement] = { - import Declaration.NonBinding + import Declaration.NonBinding val bindingLike: Indy[(Pattern.Parsed, OptIndent[NonBinding])] = { val pat = Pattern.bindParser val patPart = pat <* (maybeSpace *> Declaration.eqP *> maybeSpace) // allow = to be like a block, we can continue on the next line indented - OptIndent.blockLike(Indy.lift(patPart), Declaration.nonBindingParser, P.unit) + OptIndent.blockLike( + Indy.lift(patPart), + Declaration.nonBindingParser, + P.unit + ) } - val bindingP: P[Statement] = - (bindingLike("") <* toEOL) - .region - .map { case (region, (pat, value)) => - Bind(BindingStatement(pat, value.get, ()))(region) - } - - val paddingSP: P[Statement] = - Padding - .nonEmptyParser - .region - .map { case (region, p) => PaddingStatement(p)(region) } - - val commentP: P[Statement] = - CommentStatement.parser(_ => P.unit).region - .map { case (region, cs) => Comment(cs)(region) }.run("") - - val defBody = maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("") - val defP: P[Statement] = - DefStatement.parser(Pattern.bindParser, defBody <* toEOL).region + val bindingP: P[Statement] = + (bindingLike("") <* toEOL).region + .map { case (region, (pat, value)) => + Bind(BindingStatement(pat, value.get, ()))(region) + } + + val paddingSP: P[Statement] = + Padding.nonEmptyParser.region + .map { case (region, p) => PaddingStatement(p)(region) } + + val commentP: P[Statement] = + CommentStatement + .parser(_ => P.unit) + .region + .map { case (region, cs) => Comment(cs)(region) } + .run("") + + val defBody = maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("") + val defP: P[Statement] = + DefStatement + .parser(Pattern.bindParser, defBody <* toEOL) + .region .map { case (region, DefStatement(nm, ta, args, ret, body)) => Def(DefStatement(nm, ta, args, ret, body))(region) } - val argParser: P[(Bindable, Option[TypeRef])] = - Identifier.bindableParser ~ TypeRef.annotationParser.? - - val structKey = keySpace("struct") - - val typeParams: P[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]] = { - val kindAnnot: P[Kind.Arg] = - (maybeSpace.soft.with1 *> (P.char(':') *> maybeSpace *> Kind.paramKindParser)) - - TypeRef.typeParams(kindAnnot.?) - } - val external = { - val externalStruct = - (structKey *> (Identifier.consParser ~ Parser.nonEmptyListToList(typeParams)).region <* toEOL) - .map { - case (region, (name, tva)) => ExternalStruct(name, tva)(region) - } - - val argParser: P[(Bindable, TypeRef)] = Identifier.bindableParser ~ TypeRef.annotationParser - - val externalDef = { - - val args = P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P.char(')') - - val result = maybeSpace.with1 *> P.string("->") *> maybeSpace *> TypeRef.parser - - (((keySpace("def") *> Identifier.bindableParser ~ args ~ result).region) <* toEOL) - .map { - case (region, ((name, args), resType)) => - ExternalDef(name, args.toList, resType)(region) - } - } - - val externalVal = - (argParser <* toEOL) - .region - .map { case (region, (name, resType)) => - ExternalDef(name, Nil, resType)(region) - } - - keySpace("external") *> P.oneOf(externalStruct :: externalDef :: externalVal :: Nil) - } - - val struct = - ((structKey *> Identifier.consParser ~ typeParams.? ~ Parser.nonEmptyListToList(argParser.parensLines1Cut)).region <* toEOL) - .map { case (region, ((name, typeArgs), argsList)) => - Struct(name, typeArgs, argsList)(region) - } - - val enumP = { - val constructorP = - (Identifier.consParser ~ argParser.parensLines1Cut.?) - .map { - case (n, None) => (n, Nil) - case (n, Some(args)) => (n, args.toList) - } - - val sep = (Indy.lift(P.char(',') <* maybeSpace)) - .combineK(Indy.toEOLIndent) - .void - - val variants = Indy.lift(constructorP <* maybeSpace).nonEmptyList(sep) - - val nameVars = - OptIndent.block( - Indy.lift(keySpace("enum") *> Identifier.consParser ~ (typeParams.?)), - variants - ) - .run("") - .region - - (nameVars <* toEOL) - .map { case (region, ((ename, typeArgs), vars)) => - Enum(ename, typeArgs, vars)(region) - } - } - - // bindingP should come last so there is no ambiguity about identifiers - P.oneOf(commentP :: paddingSP :: defP :: struct :: enumP :: external :: bindingP :: Nil) + val argParser: P[(Bindable, Option[TypeRef])] = + Identifier.bindableParser ~ TypeRef.annotationParser.? + + val structKey = keySpace("struct") + + val typeParams: P[NonEmptyList[(TypeRef.TypeVar, Option[Kind.Arg])]] = { + val kindAnnot: P[Kind.Arg] = + (maybeSpace.soft.with1 *> (P.char( + ':' + ) *> maybeSpace *> Kind.paramKindParser)) + + TypeRef.typeParams(kindAnnot.?) + } + val external = { + val externalStruct = + (structKey *> (Identifier.consParser ~ Parser.nonEmptyListToList( + typeParams + )).region <* toEOL) + .map { case (region, (name, tva)) => + ExternalStruct(name, tva)(region) + } + + val argParser: P[(Bindable, TypeRef)] = + Identifier.bindableParser ~ TypeRef.annotationParser + + val externalDef = { + + val args = + P.char('(') *> maybeSpace *> argParser.nonEmptyList <* maybeSpace <* P + .char(')') + + val result = + maybeSpace.with1 *> P.string("->") *> maybeSpace *> TypeRef.parser + + (((keySpace( + "def" + ) *> Identifier.bindableParser ~ args ~ result).region) <* toEOL) + .map { case (region, ((name, args), resType)) => + ExternalDef(name, args.toList, resType)(region) + } + } + + val externalVal = + (argParser <* toEOL).region + .map { case (region, (name, resType)) => + ExternalDef(name, Nil, resType)(region) + } + + keySpace("external") *> P.oneOf( + externalStruct :: externalDef :: externalVal :: Nil + ) + } + + val struct = + ((structKey *> Identifier.consParser ~ typeParams.? ~ Parser + .nonEmptyListToList(argParser.parensLines1Cut)).region <* toEOL) + .map { case (region, ((name, typeArgs), argsList)) => + Struct(name, typeArgs, argsList)(region) + } + + val enumP = { + val constructorP = + (Identifier.consParser ~ argParser.parensLines1Cut.?) + .map { + case (n, None) => (n, Nil) + case (n, Some(args)) => (n, args.toList) + } + + val sep = (Indy + .lift(P.char(',') <* maybeSpace)) + .combineK(Indy.toEOLIndent) + .void + + val variants = Indy.lift(constructorP <* maybeSpace).nonEmptyList(sep) + + val nameVars = + OptIndent + .block( + Indy.lift( + keySpace("enum") *> Identifier.consParser ~ (typeParams.?) + ), + variants + ) + .run("") + .region + + (nameVars <* toEOL) + .map { case (region, ((ename, typeArgs), vars)) => + Enum(ename, typeArgs, vars)(region) + } + } + + // bindingP should come last so there is no ambiguity about identifiers + P.oneOf( + commentP :: paddingSP :: defP :: struct :: enumP :: external :: bindingP :: Nil + ) } - /** - * This parses the *rest* of the string (it must end with End) - */ + /** This parses the *rest* of the string (it must end with End) + */ val parser: P0[List[Statement]] = parser1.rep0 <* Parser.maybeSpacesAndLines <* P.end - private def constructor(name: Constructor, taDoc: Doc, args: List[(Bindable, Option[TypeRef])]): Doc = + private def constructor( + name: Constructor, + taDoc: Doc, + args: List[(Bindable, Option[TypeRef])] + ): Doc = Document[Identifier].document(name) + taDoc + - (if (args.nonEmpty) { Doc.char('(') + Doc.intercalate(Doc.text(", "), args.toList.map(TypeRef.argDoc[Bindable] _)) + Doc.char(')') } - else Doc.empty) + (if (args.nonEmpty) { + Doc.char('(') + Doc.intercalate( + Doc.text(", "), + args.toList.map(TypeRef.argDoc[Bindable] _) + ) + Doc.char(')') + } else Doc.empty) private val colonSpace = Doc.text(": ") - private implicit val dunit: Document[Unit] = Document.instance[Unit](_ => Doc.empty) + private implicit val dunit: Document[Unit] = + Document.instance[Unit](_ => Doc.empty) private val optKindArgs: Document[Option[Kind.Arg]] = Document { - case None => Doc.empty + case None => Doc.empty case Some(ka) => colonSpace + Kind.argDoc(ka) } implicit lazy val document: Document[Statement] = { - val db = Document[BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit]] + val db = + Document[BindingStatement[Pattern.Parsed, Declaration.NonBinding, Unit]] val dc = Document[CommentStatement[Unit]] implicit val pair: Document[OptIndent[Declaration]] = - Document.instance[OptIndent[Declaration]] { - body => - body.sepDoc + + Document.instance[OptIndent[Declaration]] { body => + body.sepDoc + OptIndent.document(Declaration.document).document(body) } val dd = DefStatement.document[Pattern.Parsed, OptIndent[Declaration]] - implicit val consDoc = Document.instance[(Constructor, List[(Bindable, Option[TypeRef])])] { - case (nm, parts) => constructor(nm, Doc.empty, parts) - } + implicit val consDoc = + Document.instance[(Constructor, List[(Bindable, Option[TypeRef])])] { + case (nm, parts) => constructor(nm, Doc.empty, parts) + } Document.instance[Statement] { case Bind(bs) => @@ -302,14 +360,13 @@ object Statement { Padding.document[Unit].document(p) case Struct(nm, typeArgs, args) => val taDoc = typeArgs match { - case None => Doc.empty + case None => Doc.empty case Some(ta) => TypeRef.docTypeArgs(ta.toList)(optKindArgs.document) } Doc.text("struct ") + constructor(nm, taDoc, args) + Doc.line case Enum(nm, typeArgs, parts) => - val (colonSep, itemSep) = parts match { - case OptIndent.SameLine(_) => (Doc.space, Doc.text(", ")) + case OptIndent.SameLine(_) => (Doc.space, Doc.text(", ")) case OptIndent.NotSameLine(_) => (Doc.empty, Doc.line) } @@ -321,30 +378,40 @@ object Statement { val indentedCons = OptIndent.document(neDoc(consDoc)).document(parts) val taDoc = typeArgs match { - case None => Doc.empty + case None => Doc.empty case Some(ta) => TypeRef.docTypeArgs(ta.toList)(optKindArgs.document) } - Doc.text("enum ") + Document[Constructor].document(nm) + taDoc + Doc.char(':') + + Doc.text("enum ") + Document[Constructor].document(nm) + taDoc + Doc + .char(':') + colonSep + indentedCons + Doc.line case ExternalDef(name, Nil, res) => - Doc.text("external ") + Document[Bindable].document(name) + Doc.text(": ") + res.toDoc + Doc.line + Doc.text("external ") + Document[Bindable].document(name) + Doc.text( + ": " + ) + res.toDoc + Doc.line case ExternalDef(name, args, res) => val argDoc = { - val da = Doc.intercalate(Doc.text(", "), args.map { case (n, tr) => - Document[Bindable].document(n) + Doc.text(": ") + tr.toDoc - }) + val da = Doc.intercalate( + Doc.text(", "), + args.map { case (n, tr) => + Document[Bindable].document(n) + Doc.text(": ") + tr.toDoc + } + ) Doc.char('(') + da + Doc.char(')') } - Doc.text("external def ") + Document[Bindable].document(name) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line + Doc.text("external def ") + Document[Bindable].document( + name + ) + argDoc + Doc.text(" -> ") + res.toDoc + Doc.line case ExternalStruct(nm, typeArgs) => val taDoc = TypeRef.docTypeArgs(typeArgs.toList) { - case None => Doc.empty + case None => Doc.empty case Some(ka) => Doc.text(": ") + Kind.argDoc(ka) } - Doc.text("external struct ") + Document[Constructor].document(nm) + taDoc + Doc.line + Doc.text("external struct ") + Document[Constructor].document( + nm + ) + taDoc + Doc.line } } @@ -353,4 +420,3 @@ object Statement { Doc.intercalate(Doc.empty, stmts.toList.map(document.document(_))) } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala index 04386f529..310720364 100644 --- a/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala +++ b/core/src/main/scala/org/bykn/bosatsu/StringUtil.scala @@ -5,14 +5,16 @@ import cats.parse.{Parser0 => P0, Parser => P} abstract class GenericStringUtil { protected def decodeTable: Map[Char, Char] - private val encodeTable = decodeTable.iterator.map { case (v, k) => (k, s"\\$v") }.toMap + private val encodeTable = decodeTable.iterator.map { case (v, k) => + (k, s"\\$v") + }.toMap private val nonPrintEscape: Array[String] = (0 until 32).map { c => val strHex = c.toHexString val strPad = List.fill(4 - strHex.length)('0').mkString s"\\u$strPad$strHex" - }.toArray + }.toArray val escapedToken: P[Char] = { def parseIntStr(p: P[Any], base: Int): P[Char] = @@ -37,26 +39,30 @@ abstract class GenericStringUtil { P.char('\\') *> after } - /** - * String content without the delimiter - */ + /** String content without the delimiter + */ def undelimitedString1(endP: P[Unit]): P[String] = - escapedToken.orElse((!endP).with1 *> P.anyChar) - .repAs + escapedToken.orElse((!endP).with1 *> P.anyChar).repAs def escapedString(q: Char): P[String] = { val end: P[Unit] = P.char(q) end *> undelimitedString1(end).orElse(P.pure("")) <* end } - def interpolatedString[A](quoteChar: Char, istart: P[Unit], interp: P0[A], iend: P[Unit]): P[List[Either[A, (Region, String)]]] = { + def interpolatedString[A]( + quoteChar: Char, + istart: P[Unit], + interp: P0[A], + iend: P[Unit] + ): P[List[Either[A, (Region, String)]]] = { val strQuote = P.char(quoteChar) val strLit: P[String] = undelimitedString1(strQuote.orElse(istart)) val notStr: P[A] = (istart ~ interp ~ iend).map { case ((_, a), _) => a } val either: P[Either[A, (Region, String)]] = - ((P.index.with1 ~ strLit ~ P.index).map { case ((s, str), l) => Right((Region(s, l), str)) }) + ((P.index.with1 ~ strLit ~ P.index) + .map { case ((s, str), l) => Right((Region(s, l), str)) }) .orElse(notStr.map(Left(_))) (strQuote ~ either.rep0 ~ strQuote).map { case ((_, lst), _) => lst } @@ -65,15 +71,17 @@ abstract class GenericStringUtil { def escape(quoteChar: Char, str: String): String = { // We can ignore escaping the opposite character used for the string // x isn't escaped anyway and is kind of a hack here - val ignoreEscape = if (quoteChar == '\'') '"' else if (quoteChar == '"') '\'' else 'x' + val ignoreEscape = + if (quoteChar == '\'') '"' else if (quoteChar == '"') '\'' else 'x' str.flatMap { c => if (c == ignoreEscape) c.toString - else encodeTable.get(c) match { - case None => - if (c < ' ') nonPrintEscape(c.toInt) - else c.toString - case Some(esc) => esc - } + else + encodeTable.get(c) match { + case None => + if (c < ' ') nonPrintEscape(c.toInt) + else c.toString + case Some(esc) => esc + } } } @@ -95,25 +103,21 @@ abstract class GenericStringUtil { if (idx >= str.length) { // done idx - } - else if (idx < 0) { + } else if (idx < 0) { // error from decodeNum idx - } - else { + } else { val c0 = str.charAt(idx) if (c0 != '\\') { sb.append(c0) loop(idx + 1) - } - else { + } else { // str(idx) == \ val nextIdx = idx + 1 if (nextIdx >= str.length) { // error we expect there to be a character after \ ~idx - } - else { + } else { val c = str.charAt(nextIdx) decodeTable.get(c) match { case Some(d) => @@ -121,10 +125,10 @@ abstract class GenericStringUtil { loop(idx + 2) case None => c match { - case 'o' => loop(decodeNum(idx + 2, 2, 8)) - case 'x' => loop(decodeNum(idx + 2, 2, 16)) - case 'u' => loop(decodeNum(idx + 2, 4, 16)) - case 'U' => loop(decodeNum(idx + 2, 8, 16)) + case 'o' => loop(decodeNum(idx + 2, 2, 8)) + case 'x' => loop(decodeNum(idx + 2, 2, 16)) + case 'u' => loop(decodeNum(idx + 2, 4, 16)) + case 'U' => loop(decodeNum(idx + 2, 8, 16)) case other => // \c is interpretted as just \c, if the character isn't escaped sb.append('\\') @@ -157,7 +161,8 @@ object StringUtil extends GenericStringUtil { ('n', '\n'), ('r', '\r'), ('t', '\t'), - ('v', 11.toChar)) // vertical tab + ('v', 11.toChar) + ) // vertical tab } object JsonStringUtil extends GenericStringUtil { @@ -171,5 +176,6 @@ object JsonStringUtil extends GenericStringUtil { ('f', 12.toChar), // form-feed ('n', '\n'), ('r', '\r'), - ('t', '\t')) + ('t', '\t') + ) } diff --git a/core/src/main/scala/org/bykn/bosatsu/Test.scala b/core/src/main/scala/org/bykn/bosatsu/Test.scala index e43d7d396..1eac323b6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Test.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Test.scala @@ -8,8 +8,8 @@ sealed abstract class Test { def failures: Option[Test] = this match { - case Test.Assertion(true, _) => None - case f@Test.Assertion(false, _) => Some(f) + case Test.Assertion(true, _) => None + case f @ Test.Assertion(false, _) => Some(f) case Test.Suite(nm, ts) => { val innerFails = ts.flatMap(_.failures.toList) if (innerFails.isEmpty) None @@ -59,7 +59,13 @@ object Test { loop(t, None, 0, 0, Doc.empty) @annotation.tailrec - def loop(ts: List[Test], lastSuite: Option[(Int, Int)], passes: Int, fails: Int, front: Doc): (Int, Int, Doc) = + def loop( + ts: List[Test], + lastSuite: Option[(Int, Int)], + passes: Int, + fails: Int, + front: Doc + ): (Int, Int, Doc) = ts match { case Nil => val sumDoc = @@ -72,10 +78,18 @@ object Test { case Assertion(true, _) :: rest => loop(rest, lastSuite, passes + 1, fails, front) case Assertion(false, label) :: rest => - loop(rest, lastSuite, passes, fails + 1, front + (Doc.line + Doc.text(label) + colonSpace + failDoc)) + loop( + rest, + lastSuite, + passes, + fails + 1, + front + (Doc.line + Doc.text(label) + colonSpace + failDoc) + ) case Suite(label, rest) :: tail => val (p, f, d) = init(rest) - val res = Doc.line + Doc.text(label) + Doc.char(':') + (Doc.lineOrSpace + d).nested(2) + val res = Doc.line + Doc.text(label) + Doc.char( + ':' + ) + (Doc.lineOrSpace + d).nested(2) loop(tail, Some((p, f)), passes + p, fails + f, front + res) } @@ -93,8 +107,8 @@ object Test { case other => // $COVERAGE-OFF$ sys.error(s"expected test value: $other") - // $COVERAGE-ON$ - } + // $COVERAGE-ON$ + } def toSuite(a: ProductValue): Test = a match { case ConsValue(Str(name), ConsValue(VList(tests), UnitValue)) => @@ -102,7 +116,7 @@ object Test { case other => // $COVERAGE-OFF$ sys.error(s"expected test value: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } def toTest(a: Value): Test = @@ -118,9 +132,9 @@ object Test { case unexpected => // $COVERAGE-OFF$ sys.error(s"unreachable if compilation has worked: $unexpected") - // $COVERAGE-ON$ + // $COVERAGE-ON$ - } + } toTest(value) } } diff --git a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala index 41d1c3d27..85dc62765 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TotalityCheck.scala @@ -19,58 +19,83 @@ object TotalityCheck { type ListPatElem = ListPart[Pattern[Cons, Type]] sealed abstract class Error - case class ArityMismatch(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv[Any], expected: Int, found: Int) extends Error - case class UnknownConstructor(cons: Cons, in: Pattern[Cons, Type], env: TypeEnv[Any]) extends Error - case class MultipleSplicesInPattern(pat: ListPat[Cons, Type], env: TypeEnv[Any]) extends Error + case class ArityMismatch( + cons: Cons, + in: Pattern[Cons, Type], + env: TypeEnv[Any], + expected: Int, + found: Int + ) extends Error + case class UnknownConstructor( + cons: Cons, + in: Pattern[Cons, Type], + env: TypeEnv[Any] + ) extends Error + case class MultipleSplicesInPattern( + pat: ListPat[Cons, Type], + env: TypeEnv[Any] + ) extends Error case class InvalidStrPat(pat: StrPat, env: TypeEnv[Any]) extends Error sealed abstract class ExprError[A] { def matchExpr: Expr.Match[A] } - case class NonTotalMatch[A](matchExpr: Expr.Match[A], missing: NonEmptyList[Pattern[Cons, Type]]) extends ExprError[A] - case class InvalidPattern[A](matchExpr: Expr.Match[A], err: Error) extends ExprError[A] - case class UnreachableBranches[A](matchExpr: Expr.Match[A], branches: NonEmptyList[Pattern[Cons, Type]]) extends ExprError[A] + case class NonTotalMatch[A]( + matchExpr: Expr.Match[A], + missing: NonEmptyList[Pattern[Cons, Type]] + ) extends ExprError[A] + case class InvalidPattern[A](matchExpr: Expr.Match[A], err: Error) + extends ExprError[A] + case class UnreachableBranches[A]( + matchExpr: Expr.Match[A], + branches: NonEmptyList[Pattern[Cons, Type]] + ) extends ExprError[A] } -/** - * Here is code for performing totality checks of matches. - * One key thing: we can assume that any two patterns are describing the same type, or otherwise - * typechecking cannot pass. So, this allows us to make certain inferences, e.g. - * _ - [_] = [_, _, *_] - * because we know the type must be a list of some kind of [_] is to be a well typed pattern. - * - * similarly, some things are ill-typed: `1 - 'foo'` doesn't make any sense. Those two patterns - * don't describe the same type. - */ +/** Here is code for performing totality checks of matches. One key thing: we + * can assume that any two patterns are describing the same type, or otherwise + * typechecking cannot pass. So, this allows us to make certain inferences, + * e.g. _ - [_] = [_, _, *_] because we know the type must be a list of some + * kind of [_] is to be a well typed pattern. + * + * similarly, some things are ill-typed: `1 - 'foo'` doesn't make any sense. + * Those two patterns don't describe the same type. + */ case class TotalityCheck(inEnv: TypeEnv[Any]) { import TotalityCheck._ - /** - * Constructors must match all items to be legal - */ - private def checkArity(nm: Cons, size: Int, pat: Pattern[Cons, Type]): Res[Unit] = + /** Constructors must match all items to be legal + */ + private def checkArity( + nm: Cons, + size: Int, + pat: Pattern[Cons, Type] + ): Res[Unit] = inEnv.typeConstructors.get(nm) match { case None => Left(NonEmptyList.of(UnknownConstructor(nm, pat, inEnv))) case Some((_, params, _)) => val cmp = params.lengthCompare(size) if (cmp == 0) validUnit - else Left(NonEmptyList.of(ArityMismatch(nm, pat, inEnv, size, params.size))) + else + Left( + NonEmptyList.of(ArityMismatch(nm, pat, inEnv, size, params.size)) + ) } private[this] val validUnit: Res[Unit] = Right(()) - /** - * Check that a given pattern follows all the rules. - * - * The main rules are: - * * in strings, you cannot have two adjacent variable patterns (where should one end?) - * * in lists we cannot have more than one variable pattern (maybe relaxed later to the above) - */ + + /** Check that a given pattern follows all the rules. + * + * The main rules are: * in strings, you cannot have two adjacent variable + * patterns (where should one end?) * in lists we cannot have more than one + * variable pattern (maybe relaxed later to the above) + */ def validatePattern(p: Pattern[Cons, Type]): Res[Unit] = p match { - case lp@ListPat(parts) => + case lp @ ListPat(parts) => val twoAdj = lp.toSeqPattern.toList.sliding(2).exists { case Seq(SeqPart.Wildcard, SeqPart.Wildcard) => true - case _ => false + case _ => false } val outer = if (!twoAdj) validUnit @@ -79,12 +104,12 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { val inners: Res[Unit] = parts.parTraverse_ { case ListPart.Item(p) => validatePattern(p) - case _ => validUnit + case _ => validUnit } (outer, inners).parMapN { (_, _) => () } - case sp@StrPat(_) => + case sp @ StrPat(_) => val simp = sp.toSeqPattern if (simp.normalize == simp) validUnit else Left(NonEmptyList(InvalidStrPat(sp, inEnv), Nil)) @@ -98,20 +123,19 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case _ => validUnit } - /** - * Check that an expression, and all inner expressions, are total, or return - * a NonEmptyList of matches that are not total - */ + /** Check that an expression, and all inner expressions, are total, or return + * a NonEmptyList of matches that are not total + */ def checkExpr[A](expr: Expr[A]): ValidatedNel[ExprError[A], Unit] = { import Expr._ expr match { - case Annotation(e, _, _) => checkExpr(e) - case Generic(_, e) => checkExpr(e) - case Lambda(_, e, _) => checkExpr(e) + case Annotation(e, _, _) => checkExpr(e) + case Generic(_, e) => checkExpr(e) + case Lambda(_, e, _) => checkExpr(e) case Global(_, _, _) | Local(_, _) | Literal(_, _) => Validated.valid(()) - case App(fn, args, _) => checkExpr(fn) *> args.traverse_(checkExpr) + case App(fn, args, _) => checkExpr(fn) *> args.traverse_(checkExpr) case Let(_, e1, e2, _, _) => checkExpr(e1) *> checkExpr(e2) - case m@Match(arg, branches, _) => + case m @ Match(arg, branches, _) => val patterns = branches.toList.map(_._1) patterns .parTraverse_(validatePattern) @@ -137,7 +161,9 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { val unr = patternSetOps.unreachableBranches(patterns) NonEmptyList.fromList(unr) match { case Some(nel) => - Validated.invalidNel(UnreachableBranches(m, nel): ExprError[A]) + Validated.invalidNel( + UnreachableBranches(m, nel): ExprError[A] + ) case None => Validated.valid(()) } } @@ -164,15 +190,22 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { def isTotal(p: Patterns): Boolean = missingBranches(p).isEmpty - private def structToList(n: Cons, args: List[Pattern[Cons, Type]]): Option[Pattern.ListPat[Cons, Type]] = + private def structToList( + n: Cons, + args: List[Pattern[Cons, Type]] + ): Option[Pattern.ListPat[Cons, Type]] = (n, args) match { - case ((PackageName.PredefName, Constructor("EmptyList")), Nil) => Some(Pattern.ListPat(Nil)) - case ((PackageName.PredefName, Constructor("NonEmptyList")), h :: t :: Nil) => + case ((PackageName.PredefName, Constructor("EmptyList")), Nil) => + Some(Pattern.ListPat(Nil)) + case ( + (PackageName.PredefName, Constructor("NonEmptyList")), + h :: t :: Nil + ) => val tailRes = t match { case Pattern.PositionalStruct(n, a) => structToList(n, a).map(_.parts) case Pattern.ListPat(parts) => Some(parts) - case _ => + case _ => if (isTotal(t :: Nil)) Some(Pattern.ListPart.WildList :: Nil) else None } @@ -189,13 +222,18 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { SetOps.imap[SeqPattern[Pattern[Cons, Type]], ListPat[Cons, Type]]( seqP, ListPat.fromSeqPattern(_), - _.toSeqPattern) + _.toSeqPattern + ) private val strPatternSetOps: SetOps[StrPat] = SetOps.imap[SeqPattern[Char], StrPat]( - SeqPattern.seqPatternSetOps(SeqPart.part1SetOps(SetOps.distinct[Char]), implicitly), + SeqPattern.seqPatternSetOps( + SeqPart.part1SetOps(SetOps.distinct[Char]), + implicitly + ), StrPat.fromSeqPattern(_), - _.toSeqPattern) + _.toSeqPattern + ) private val getProd: Int => SetOps[List[Pattern[Cons, Type]]] = memoizeDagHashed[Int, SetOps[List[Pattern[Cons, Type]]]] { @@ -212,8 +250,9 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { { case (h, t) => h :: t }, { case h :: t => (h, t) - case _ => sys.error(s"invalid arity: $arity, found empty list") - }) + case _ => sys.error(s"invalid arity: $arity, found empty list") + } + ) } lazy val patternSetOps: SetOps[Pattern[Cons, Type]] = @@ -221,108 +260,126 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { val top = Some(WildCard) def intersection( - left: Pattern[Cons, Type], - right: Pattern[Cons, Type]): List[Pattern[Cons, Type]] = - (left, right) match { - case (Var(va), Var(vb)) => Var(Ordering[Bindable].min(va, vb)) :: Nil - case (Named(va, pa), Named(vb, pb)) if va == vb => - intersection(pa, pb).map(Named(va, _)) - case (Named(_, pa), r) => intersection(pa, r) - case (l, Named(_, pb)) => intersection(l, pb) - case (WildCard, v) => v :: Nil - case (v, WildCard) => v :: Nil - case (_, _) if left == right => left :: Nil - case (Var(_), v) => v :: Nil - case (v, Var(_)) => v :: Nil - case (Annotation(p, _), t) => intersection(p, t) - case (t, Annotation(p, _)) => intersection(t, p) - case (Literal(a), Literal(b)) => - if (a == b) left :: Nil - else Nil - case (Literal(Lit.Str(s)), p@StrPat(_)) => - if (p.matches(s)) left :: Nil - else Nil - case (p@StrPat(_), Literal(Lit.Str(s))) => - if (p.matches(s)) right :: Nil - else Nil - case (p1@StrPat(_), p2@StrPat(_)) => - strPatternSetOps.intersection(p1, p2) - case (lp@ListPat(_), rp@ListPat(_)) => - listPatternSetOps.intersection(lp, rp) - case (PositionalStruct(n, as), rp@ListPat(_)) => - structToList(n, as) match { - case Some(lp) => intersection(lp, rp) - case None => - if (isTop(rp)) left :: Nil - else Nil - } - case (lp@ListPat(_), PositionalStruct(n, as)) => - structToList(n, as) match { - case Some(rp) => intersection(lp, rp) - case None => - if (isTop(lp)) right :: Nil - else Nil - } - case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => - if (ln == rn) { - val la = lps.size - if (rps.size == la) { - // the arity must match or check expr fails - // if the arity doesn't match, just consider this - // a mismatch - unifyUnion(getProd(la).intersection(lps, rps) - .map(PositionalStruct(ln, _): Pattern[Cons, Type])) - } + left: Pattern[Cons, Type], + right: Pattern[Cons, Type] + ): List[Pattern[Cons, Type]] = + (left, right) match { + case (Var(va), Var(vb)) => Var(Ordering[Bindable].min(va, vb)) :: Nil + case (Named(va, pa), Named(vb, pb)) if va == vb => + intersection(pa, pb).map(Named(va, _)) + case (Named(_, pa), r) => intersection(pa, r) + case (l, Named(_, pb)) => intersection(l, pb) + case (WildCard, v) => v :: Nil + case (v, WildCard) => v :: Nil + case (_, _) if left == right => left :: Nil + case (Var(_), v) => v :: Nil + case (v, Var(_)) => v :: Nil + case (Annotation(p, _), t) => intersection(p, t) + case (t, Annotation(p, _)) => intersection(t, p) + case (Literal(a), Literal(b)) => + if (a == b) left :: Nil + else Nil + case (Literal(Lit.Str(s)), p @ StrPat(_)) => + if (p.matches(s)) left :: Nil + else Nil + case (p @ StrPat(_), Literal(Lit.Str(s))) => + if (p.matches(s)) right :: Nil + else Nil + case (p1 @ StrPat(_), p2 @ StrPat(_)) => + strPatternSetOps.intersection(p1, p2) + case (lp @ ListPat(_), rp @ ListPat(_)) => + listPatternSetOps.intersection(lp, rp) + case (PositionalStruct(n, as), rp @ ListPat(_)) => + structToList(n, as) match { + case Some(lp) => intersection(lp, rp) + case None => + if (isTop(rp)) left :: Nil else Nil - } - else Nil - case (Union(a, b), p) => - val u = unifyUnion(a :: b.toList) - unifyUnion(u.flatMap(intersection(_, p))) - case (p, Union(a, b)) => - val u = unifyUnion(a :: b.toList) - unifyUnion(u.flatMap(intersection(p, _))) - case (_, _) => - if (isTopCheap(right)) left :: Nil - else if (isTopCheap(left)) right :: Nil - else Nil - } + } + case (lp @ ListPat(_), PositionalStruct(n, as)) => + structToList(n, as) match { + case Some(rp) => intersection(lp, rp) + case None => + if (isTop(lp)) right :: Nil + else Nil + } + case (PositionalStruct(ln, lps), PositionalStruct(rn, rps)) => + if (ln == rn) { + val la = lps.size + if (rps.size == la) { + // the arity must match or check expr fails + // if the arity doesn't match, just consider this + // a mismatch + unifyUnion( + getProd(la) + .intersection(lps, rps) + .map(PositionalStruct(ln, _): Pattern[Cons, Type]) + ) + } else Nil + } else Nil + case (Union(a, b), p) => + val u = unifyUnion(a :: b.toList) + unifyUnion(u.flatMap(intersection(_, p))) + case (p, Union(a, b)) => + val u = unifyUnion(a :: b.toList) + unifyUnion(u.flatMap(intersection(p, _))) + case (_, _) => + if (isTopCheap(right)) left :: Nil + else if (isTopCheap(left)) right :: Nil + else Nil + } - def difference(left: Pattern[Cons, Type], right: Pattern[Cons, Type]): Patterns = + def difference( + left: Pattern[Cons, Type], + right: Pattern[Cons, Type] + ): Patterns = (left, right) match { - case (_, WildCard | Var(_)) => Nil - case (Named(_, p), r) => difference(p, r) - case (l, Named(_, p)) => difference(l, p) - case (Annotation(p, _), r) => difference(p, r) - case (l, Annotation(p, _)) => difference(l, p) - case (Var(v), listPat@ListPat(_)) => + case (_, WildCard | Var(_)) => Nil + case (Named(_, p), r) => difference(p, r) + case (l, Named(_, p)) => difference(l, p) + case (Annotation(p, _), r) => difference(p, r) + case (l, Annotation(p, _)) => difference(l, p) + case (Var(v), listPat @ ListPat(_)) => // v is the same as [*v] for well typed expressions - listPatternSetOps.difference(ListPat(ListPart.NamedList(v) :: Nil), listPat) - case (left@ListPat(_), right@ListPat(_)) => + listPatternSetOps.difference( + ListPat(ListPart.NamedList(v) :: Nil), + listPat + ) + case (left @ ListPat(_), right @ ListPat(_)) => listPatternSetOps.difference(left, right) - case (_, listPat@ListPat(_)) if isTop(left) => + case (_, listPat @ ListPat(_)) if isTop(left) => // _ is the same as [*_] for well typed expressions - listPatternSetOps.difference(ListPat(ListPart.WildList :: Nil), listPat) - case (_, listPat@ListPat(_)) if listPatternSetOps.isTop(listPat) => + listPatternSetOps.difference( + ListPat(ListPart.WildList :: Nil), + listPat + ) + case (_, listPat @ ListPat(_)) if listPatternSetOps.isTop(listPat) => Nil - case (Literal(Lit.Str(str)), p@StrPat(_)) if p.matches(str) => + case (Literal(Lit.Str(str)), p @ StrPat(_)) if p.matches(str) => Nil - case (sa@StrPat(_), Literal(Lit.Str(str))) => - if (sa.matches(str)) strPatternSetOps.difference(sa, StrPat.fromLitStr(str)) + case (sa @ StrPat(_), Literal(Lit.Str(str))) => + if (sa.matches(str)) + strPatternSetOps.difference(sa, StrPat.fromLitStr(str)) else sa :: Nil - case (sa@StrPat(_), sb@StrPat(_)) => + case (sa @ StrPat(_), sb @ StrPat(_)) => strPatternSetOps.difference(sa, sb) - case (WildCard, right@StrPat(_)) => + case (WildCard, right @ StrPat(_)) => // _ is the same as "${_}" for well typed expressions - strPatternSetOps.difference(StrPat(NonEmptyList(StrPart.WildStr, Nil)), right) - case (Var(v), right@StrPat(_)) => + strPatternSetOps.difference( + StrPat(NonEmptyList(StrPart.WildStr, Nil)), + right + ) + case (Var(v), right @ StrPat(_)) => // v is the same as "${v}" for well typed expressions - strPatternSetOps.difference(StrPat(NonEmptyList(StrPart.NamedStr(v), Nil)), right) - case (llit@Literal(l), Literal(r)) => + strPatternSetOps.difference( + StrPat(NonEmptyList(StrPart.NamedStr(v), Nil)), + right + ) + case (llit @ Literal(l), Literal(r)) => if (l == r) Nil else (llit :: Nil) // below here it is starting to get complex - case (_, _) if disjoint(left, right) => left :: Nil + case (_, _) if disjoint(left, right) => left :: Nil case (_, _) if (left == right) || subset(left, right) => Nil case (Union(a, b), Union(c, d)) => unifyUnion(differenceAll(a :: b.toList, c :: d.toList)) @@ -339,23 +396,24 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { // the arity must match or check expr fails // if the arity doesn't match, just consider this // a mismatch - unifyUnion(getProd(la).difference(lp, rp) - .map(PositionalStruct(ln, _): Pattern[Cons, Type])) - } - else (left :: Nil) - } - else { + unifyUnion( + getProd(la) + .difference(lp, rp) + .map(PositionalStruct(ln, _): Pattern[Cons, Type]) + ) + } else (left :: Nil) + } else { left :: Nil } - case (PositionalStruct(n, as), rp@ListPat(_)) => + case (PositionalStruct(n, as), rp @ ListPat(_)) => structToList(n, as) match { case Some(lp) => difference(lp, rp) - case None => left :: Nil + case None => left :: Nil } - case (lp@ListPat(_), PositionalStruct(n, as)) => + case (lp @ ListPat(_), PositionalStruct(n, as)) => structToList(n, as) match { case Some(rp) => difference(lp, rp) - case None => left :: Nil + case None => left :: Nil } case (_, PositionalStruct(nm, _)) if isTop(left) => inEnv.definedTypeFor(nm) match { @@ -367,7 +425,8 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { dt.constructors.flatMap { case cf if (dt.packageName, cf.name) == nm => // we can replace _ with Struct(_, _...) - val newWild = PositionalStruct(nm, cf.args.map(_ => WildCard)) + val newWild = + PositionalStruct(nm, cf.args.map(_ => WildCard)) difference(newWild, right) case cf => @@ -377,7 +436,10 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { if (Type.hasNoVars(t._2)) Annotation(WildCard, t._2) else WildCard - PositionalStruct((dt.packageName, cf.name), cf.args.map(argToPat)) :: Nil + PositionalStruct( + (dt.packageName, cf.name), + cf.args.map(argToPat) + ) :: Nil } } case (_, _) => @@ -385,24 +447,23 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { if (isTop(right)) Nil else if (isTop(left)) { right match { - case StrPat(_) => difference(StrPat.Wild, right) + case StrPat(_) => difference(StrPat.Wild, right) case ListPat(_) => difference(ListPat.Wild, right) - case _ => + case _ => // we can't solve this left :: Nil } - } - else left :: Nil + } else left :: Nil } def isTop(p: Pattern[Cons, Type]): Boolean = p match { case Pattern.WildCard | Pattern.Var(_) => true - case Pattern.Named(_, p) => isTop(p) - case Pattern.Annotation(p, _) => isTop(p) - case Pattern.Literal(_) => false // literals are not total - case s@Pattern.StrPat(_) => strPatternSetOps.isTop(s) - case l@Pattern.ListPat(_) => listPatternSetOps.isTop(l) + case Pattern.Named(_, p) => isTop(p) + case Pattern.Annotation(p, _) => isTop(p) + case Pattern.Literal(_) => false // literals are not total + case s @ Pattern.StrPat(_) => strPatternSetOps.isTop(s) + case l @ Pattern.ListPat(_) => listPatternSetOps.isTop(l) case Pattern.PositionalStruct(name, params) => inEnv.definedTypeFor(name) match { case None => @@ -422,11 +483,11 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { private def isTopCheap(p: Pattern[Cons, Type]): Boolean = p match { case Pattern.WildCard | Pattern.Var(_) => true - case Pattern.Named(_, p) => isTopCheap(p) - case Pattern.Annotation(p, _) => isTopCheap(p) - case Pattern.Literal(_) => false // literals are not total - case s@Pattern.StrPat(_) => strPatternSetOps.isTop(s) - case l@Pattern.ListPat(_) => listPatternSetOps.isTop(l) + case Pattern.Named(_, p) => isTopCheap(p) + case Pattern.Annotation(p, _) => isTopCheap(p) + case Pattern.Literal(_) => false // literals are not total + case s @ Pattern.StrPat(_) => strPatternSetOps.isTop(s) + case l @ Pattern.ListPat(_) => listPatternSetOps.isTop(l) case Pattern.PositionalStruct(name, params) => inEnv.definedTypeFor(name) match { case None => @@ -446,17 +507,21 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { def loop(a: Pattern[Cons, Type], b: Pattern[Cons, Type]): Boolean = (a, b) match { - case _ if a == b => true + case _ if a == b => true case (_, Union(h, t)) => (h :: t).exists(loop(a, _)) - case (Literal(Lit.Str(s)), sp@Pattern.StrPat(_)) => sp.matches(s) - case (s1@Pattern.StrPat(_), s2@Pattern.StrPat(_)) => + case (Literal(Lit.Str(s)), sp @ Pattern.StrPat(_)) => + sp.matches(s) + case (s1 @ Pattern.StrPat(_), s2 @ Pattern.StrPat(_)) => strPatternSetOps.subset(s1, s2) - case (l1@Pattern.ListPat(_), l2@Pattern.ListPat(_)) => + case (l1 @ Pattern.ListPat(_), l2 @ Pattern.ListPat(_)) => listPatternSetOps.subset(l1, l2) - case (Pattern.PositionalStruct(ln, lp), Pattern.PositionalStruct(rn, rp)) => + case ( + Pattern.PositionalStruct(ln, lp), + Pattern.PositionalStruct(rn, rp) + ) => (ln == rn) && - (lp.size == rp.size) && - (lp.zip(rp).forall { case (l, r) => loop(l, r) }) + (lp.size == rp.size) && + (lp.zip(rp).forall { case (l, r) => loop(l, r) }) case _ => false } @@ -467,15 +532,20 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { if (u0.exists(isTopCheap(_))) WildCard :: Nil else { - val u = u0.flatMap { p => Pattern.flatten(normalizePattern(p)).toList } + val u = u0.flatMap { p => + Pattern.flatten(normalizePattern(p)).toList + } - val lps = listPatternSetOps.unifyUnion(u.collect { case lp@Pattern.ListPat(_) => lp }) - val sps = strPatternSetOps.unifyUnion(u.collect { case sp@Pattern.StrPat(_) => sp }) - val strs = u.collect { case Pattern.Literal(ls@Lit.Str(_)) => ls } + val lps = listPatternSetOps.unifyUnion(u.collect { + case lp @ Pattern.ListPat(_) => lp + }) + val sps = strPatternSetOps.unifyUnion(u.collect { + case sp @ Pattern.StrPat(_) => sp + }) + val strs = u.collect { case Pattern.Literal(ls @ Lit.Str(_)) => ls } val distinctStrs = - strs - .distinct + strs.distinct .filterNot { s => sps.exists(_.matches(s.toStr)) } @@ -483,10 +553,11 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { .map(Pattern.Literal(_)) val notListStr = u.filterNot { - case Pattern.ListPat(_) | Pattern.StrPat(_) | Pattern.Literal(Lit.Str(_)) | Pattern.PositionalStruct(_, _) => true + case Pattern.ListPat(_) | Pattern.StrPat(_) | + Pattern.Literal(Lit.Str(_)) | Pattern.PositionalStruct(_, _) => + true case _ => false - } - .distinct + }.distinct val structs = u .collect { case Pattern.PositionalStruct(n, a) => (n, a) } @@ -499,34 +570,33 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { } .toList - (lps ::: sps ::: distinctStrs ::: notListStr ::: structs).sorted + (lps ::: sps ::: distinctStrs ::: notListStr ::: structs).sorted } } - /** - * recursively replace as much as possible with Wildcard - * This should match exactly the same set for the same type as - * the previous pattern, without any binding names - */ + /** recursively replace as much as possible with Wildcard This should match + * exactly the same set for the same type as the previous pattern, without + * any binding names + */ def normalizePattern(p: Pattern[Cons, Type]): Pattern[Cons, Type] = p match { case WildCard | Literal(_) => p - case Var(_) => WildCard - case Named(_, p) => normalizePattern(p) - case Annotation(p, _) => normalizePattern(p) - case u@Union(_, _) => + case Var(_) => WildCard + case Named(_, p) => normalizePattern(p) + case Annotation(p, _) => normalizePattern(p) + case u @ Union(_, _) => val flattened = Pattern.flatten(u).map(normalizePattern(_)) patternSetOps.unifyUnion(flattened.toList) match { case Nil => // $COVERAGE-OFF$ sys.error("unreachable: union can't remove items") - // $COVERAGE-ON$ + // $COVERAGE-ON$ case h :: t => Pattern.union(h, t) } case _ if patternSetOps.isTop(p) => WildCard - case strPat@StrPat(_) => + case strPat @ StrPat(_) => StrPat.fromSeqPattern(strPat.toSeqPattern) case ListPat(parts) => val p1 = @@ -541,15 +611,13 @@ case class TotalityCheck(inEnv: TypeEnv[Any]) { case PositionalStruct(n, params) => val normParams = params.map(normalizePattern) structToList(n, normParams) match { - case None => PositionalStruct(n, normParams) + case None => PositionalStruct(n, normParams) case Some(lp) => lp } } - /** - * This tells if two patterns for the same type - * would match the same values - */ + /** This tells if two patterns for the same type would match the same values + */ val eqPat: Eq[Pattern[Cons, Type]] = new Eq[Pattern[Cons, Type]] { def eqv(l: Pattern[Cons, Type], r: Pattern[Cons, Type]) = diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala b/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala index a3eba3ed5..6e4222863 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala @@ -2,9 +2,16 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import cats.parse.{Parser => P} -import org.typelevel.paiges.{ Doc, Document } - -import Parser.{ Combinators, MaybeTupleOrParens, lowerIdent, maybeSpace, maybeSpacesAndLines, keySpace } +import org.typelevel.paiges.{Doc, Document} + +import Parser.{ + Combinators, + MaybeTupleOrParens, + lowerIdent, + maybeSpace, + maybeSpacesAndLines, + keySpace +} abstract class TypeParser[A] { /* @@ -21,19 +28,26 @@ abstract class TypeParser[A] { */ protected def unapplyRoot(a: A): Option[Doc] protected def unapplyFn(a: A): Option[(NonEmptyList[A], A)] - protected def unapplyUniversal(a: A): Option[(List[(String, Option[Kind])], A)] + protected def unapplyUniversal( + a: A + ): Option[(List[(String, Option[Kind])], A)] protected def unapplyTypeApply(a: A): Option[(A, List[A])] protected def unapplyTuple(a: A): Option[List[A]] - final val parser: P[A] = P.recursive[A] { recurse => val univItem: P[(String, Option[Kind])] = { val kindP: P[Kind] = - (maybeSpacesAndLines.soft.with1 *> (P.char(':') *> maybeSpacesAndLines *> Kind.parser)) + (maybeSpacesAndLines.soft.with1 *> (P.char( + ':' + ) *> maybeSpacesAndLines *> Kind.parser)) lowerIdent ~ kindP.? } val lambda: P[MaybeTupleOrParens[A]] = - (keySpace("forall") *> univItem.nonEmptyListOfWs(maybeSpacesAndLines) ~ (maybeSpacesAndLines *> P.char('.') *> maybeSpacesAndLines *> recurse)) + (keySpace("forall") *> univItem.nonEmptyListOfWs( + maybeSpacesAndLines + ) ~ (maybeSpacesAndLines *> P.char( + '.' + ) *> maybeSpacesAndLines *> recurse)) .map { case (args, e) => MaybeTupleOrParens.Bare(universal(args, e)) } val tupleOrParens: P[MaybeTupleOrParens[A]] = @@ -41,22 +55,25 @@ abstract class TypeParser[A] { def nonArrow(mtp: MaybeTupleOrParens[A]): A = mtp match { - case MaybeTupleOrParens.Bare(a) => a + case MaybeTupleOrParens.Bare(a) => a case MaybeTupleOrParens.Parens(a) => a case MaybeTupleOrParens.Tuple(as) => makeTuple(as) } val appP: P[MaybeTupleOrParens[A] => MaybeTupleOrParens[A]] = - (P.char('[') *> maybeSpacesAndLines *> recurse.nonEmptyListOfWs(maybeSpacesAndLines) <* maybeSpacesAndLines <* P.char(']')) + (P.char('[') *> maybeSpacesAndLines *> recurse.nonEmptyListOfWs( + maybeSpacesAndLines + ) <* maybeSpacesAndLines <* P.char(']')) .map { args => - { left => MaybeTupleOrParens.Bare(applyTypes(nonArrow(left), args)) } } val arrowP: P[MaybeTupleOrParens[A] => MaybeTupleOrParens[A]] = - ((maybeSpace.with1.soft ~ P.string("->") ~ maybeSpacesAndLines) *> recurse) + ((maybeSpace.with1.soft ~ P.string( + "->" + ) ~ maybeSpacesAndLines) *> recurse) // TODO remove the flatMap when we support FunctionN .map { right => { @@ -66,7 +83,7 @@ abstract class TypeParser[A] { MaybeTupleOrParens.Bare(makeFn(NonEmptyList.one(a), right)) case MaybeTupleOrParens.Tuple(items) => val args = NonEmptyList.fromList(items) match { - case None => NonEmptyList.one(makeTuple(Nil)) + case None => NonEmptyList.one(makeTuple(Nil)) case Some(nel) => nel } // We know th @@ -74,8 +91,11 @@ abstract class TypeParser[A] { } } - P.oneOf(lambda :: parseRoot.map(MaybeTupleOrParens.Bare(_)) :: tupleOrParens :: Nil) - .maybeAp(appP) + P.oneOf( + lambda :: parseRoot.map( + MaybeTupleOrParens.Bare(_) + ) :: tupleOrParens :: Nil + ).maybeAp(appP) .maybeAp(arrowP) .map(nonArrow) } @@ -93,7 +113,7 @@ abstract class TypeParser[A] { case None => () case Some(ts) => return ts match { - case Nil => unitDoc + case Nil => unitDoc case h :: Nil => Doc.char('(') + toDoc(h) + commaPar case twoAndMore => p(Doc.intercalate(commaSpace, twoAndMore.map(toDoc))) @@ -106,21 +126,22 @@ abstract class TypeParser[A] { val args = if (ins.tail.isEmpty) { val in0 = ins.head val din = toDoc(in0) - unapplyFn(in0).orElse(unapplyUniversal(in0)).orElse(unapplyTuple(in0)) match { + unapplyFn(in0) + .orElse(unapplyUniversal(in0)) + .orElse(unapplyTuple(in0)) match { case Some(_) => par(din) - case None => din + case None => din } - } - else { + } else { // there is more than 1 arg so parens are always used: (a, b) -> c par(Doc.intercalate(commaSpace, ins.toList.map(toDoc))) - + } return (args + (spaceArrow + toDoc(out))) } unapplyRoot(a) match { - case None => () + case None => () case Some(d) => return d } @@ -129,20 +150,25 @@ abstract class TypeParser[A] { case Some((of, args)) => val ofDoc0 = toDoc(of) val ofDoc = unapplyUniversal(of) match { - case None => ofDoc0 + case None => ofDoc0 case Some(_) => par(ofDoc0) } - return ofDoc + Doc.char('[') + Doc.intercalate(commaSpace, args.map(toDoc)) + Doc.char(']') + return ofDoc + Doc.char('[') + Doc.intercalate( + commaSpace, + args.map(toDoc) + ) + Doc.char(']') } unapplyUniversal(a) match { case None => () case Some((vars, in)) => - return forAll + Doc.intercalate(commaSpace, + return forAll + Doc.intercalate( + commaSpace, vars.map { - case (a, None) => Doc.text(a) + case (a, None) => Doc.text(a) case (a, Some(k)) => Doc.text(a) + TypeParser.colonSpace + k.toDoc - }) + + } + ) + Doc.char('.') + Doc.space + toDoc(in) } diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala b/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala index 63798ddf2..f55d3fc3b 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeRef.scala @@ -6,29 +6,27 @@ import cats.implicits._ import cats.parse.{Parser => P, Parser0} import org.bykn.bosatsu.rankn.Type import org.bykn.bosatsu.{TypeName => Name} -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import Parser.{lowerIdent, maybeSpace, Combinators} -/** - * This AST is the syntactic version of Type - * it is shaped slightly differently to match the way - * the syntax looks (nested non empty lists are explicit - * whereas we use a recursion/cons style in Type - */ +/** This AST is the syntactic version of Type it is shaped slightly differently + * to match the way the syntax looks (nested non empty lists are explicit + * whereas we use a recursion/cons style in Type + */ sealed abstract class TypeRef { import TypeRef._ def toDoc: Doc = TypeRef.document.document(this) - /** - * Nested TypeForAll can be combined, and should be generally - */ + /** Nested TypeForAll can be combined, and should be generally + */ def normalizeForAll: TypeRef = this match { case TypeVar(_) | TypeName(_) => this - case TypeArrow(a, b) => TypeArrow(a.map(_.normalizeForAll), b.normalizeForAll) + case TypeArrow(a, b) => + TypeArrow(a.map(_.normalizeForAll), b.normalizeForAll) case TypeApply(a, bs) => TypeApply(a.normalizeForAll, bs.map(_.normalizeForAll)) case TypeForAll(pars0, TypeForAll(pars1, e)) => @@ -36,10 +34,13 @@ sealed abstract class TypeRef { TypeForAll(pars0 ::: pars1, e).normalizeForAll case TypeForAll(pars, e) => // Remove `Some(Type)` since that's the default - TypeForAll(pars.map { - case (v, Some(Kind.Type)) => (v, None) - case other => other - }, e.normalizeForAll) + TypeForAll( + pars.map { + case (v, Some(Kind.Type)) => (v, None) + case other => other + }, + e.normalizeForAll + ) case TypeTuple(ts) => TypeTuple(ts.map(_.normalizeForAll)) } @@ -48,10 +49,9 @@ sealed abstract class TypeRef { object TypeRef { private val colonSpace = Doc.text(": ") - def argDoc[A: Document](st: (A, Option[TypeRef])): Doc = st match { - case (s, None) => Document[A].document(s) + case (s, None) => Document[A].document(s) case (s, Some(tr)) => Document[A].document(s) + colonSpace + (tr.toDoc) } @@ -69,7 +69,10 @@ object TypeRef { case class TypeApply(of: TypeRef, args: NonEmptyList[TypeRef]) extends TypeRef - case class TypeForAll(params: NonEmptyList[(TypeVar, Option[Kind])], in: TypeRef) extends TypeRef + case class TypeForAll( + params: NonEmptyList[(TypeVar, Option[Kind])], + in: TypeRef + ) extends TypeRef case class TypeTuple(params: List[TypeRef]) extends TypeRef implicit val typeRefOrdering: Ordering[TypeRef] = @@ -81,30 +84,35 @@ object TypeRef { def compare(a: TypeRef, b: TypeRef): Int = (a, b) match { - case (TypeVar(v0), TypeVar(v1)) => v0.compareTo(v1) - case (TypeVar(_), _) => -1 + case (TypeVar(v0), TypeVar(v1)) => v0.compareTo(v1) + case (TypeVar(_), _) => -1 case (TypeName(v0), TypeName(v1)) => Ordering[Name].compare(v0, v1) - case (TypeName(_), TypeVar(_)) => 1 - case (TypeName(_), _) => -1 + case (TypeName(_), TypeVar(_)) => 1 + case (TypeName(_), _) => -1 case (TypeArrow(a0, b0), TypeArrow(a1, b1)) => val c = nelTR.compare(a0, a1) if (c == 0) compare(b0, b1) else c case (TypeArrow(_, _), TypeVar(_) | TypeName(_)) => 1 - case (TypeArrow(_, _), _) => -1 + case (TypeArrow(_, _), _) => -1 case (TypeApply(o0, a0), TypeApply(o1, a1)) => val c = compare(o0, o1) if (c != 0) c else list.compare(a0.toList, a1.toList) - case (TypeApply(_, _), TypeVar(_) | TypeName(_) | TypeArrow(_, _)) => 1 - case (TypeApply(_, _), _) => -1 + case (TypeApply(_, _), TypeVar(_) | TypeName(_) | TypeArrow(_, _)) => + 1 + case (TypeApply(_, _), _) => -1 case (TypeForAll(p0, in0), TypeForAll(p1, in1)) => // TODO, we could normalize the parmeters here val c = nelistKind.compare(p0, p1) if (c == 0) compare(in0, in1) else c - case (TypeForAll(_, _), TypeVar(_) | TypeName(_) | TypeArrow(_, _) | TypeApply(_, _)) => 1 - case (TypeForAll(_, _), _) => -1 + case ( + TypeForAll(_, _), + TypeVar(_) | TypeName(_) | TypeArrow(_, _) | TypeApply(_, _) + ) => + 1 + case (TypeForAll(_, _), _) => -1 case (TypeTuple(t0), TypeTuple(t1)) => list.compare(t0, t1) - case (TypeTuple(_), _) => 1 + case (TypeTuple(_), _) => 1 } } @@ -117,7 +125,8 @@ object TypeRef { } def makeFn(in: NonEmptyList[TypeRef], out: TypeRef) = TypeArrow(in, out) - def applyTypes(cons: TypeRef, args: NonEmptyList[TypeRef]) = TypeApply(cons, args) + def applyTypes(cons: TypeRef, args: NonEmptyList[TypeRef]) = + TypeApply(cons, args) def universal(vars: NonEmptyList[(String, Option[Kind])], in: TypeRef) = TypeForAll(vars.map { case (s, k) => (TypeVar(s), k) }, in) @@ -126,32 +135,35 @@ object TypeRef { def unapplyRoot(a: TypeRef): Option[Doc] = a match { case TypeName(n) => Some(Document[Identifier].document(n.ident)) - case TypeVar(s) => Some(Doc.text(s)) - case _ => None + case TypeVar(s) => Some(Doc.text(s)) + case _ => None } def unapplyFn(a: TypeRef): Option[(NonEmptyList[TypeRef], TypeRef)] = a match { case TypeArrow(a, b) => Some((a, b)) - case _ => None + case _ => None } - def unapplyUniversal(a: TypeRef): Option[(List[(String, Option[Kind])], TypeRef)] = + def unapplyUniversal( + a: TypeRef + ): Option[(List[(String, Option[Kind])], TypeRef)] = a match { - case TypeForAll(vs, a) => Some(((vs.map { case (v, k) => (v.asString, k) }).toList, a)) + case TypeForAll(vs, a) => + Some(((vs.map { case (v, k) => (v.asString, k) }).toList, a)) case _ => None } def unapplyTypeApply(a: TypeRef): Option[(TypeRef, List[TypeRef])] = a match { case TypeApply(a, args) => Some((a, args.toList)) - case _ => None + case _ => None } def unapplyTuple(a: TypeRef): Option[List[TypeRef]] = a match { case TypeTuple(as) => Some(as) - case _ => None + case _ => None } } @@ -165,7 +177,9 @@ object TypeRef { targs match { case Nil => Doc.empty case nonEmpty => - val params = nonEmpty.map { case (TypeRef.TypeVar(v), a) => Doc.text(v) + aDoc(a) } + val params = nonEmpty.map { case (TypeRef.TypeVar(v), a) => + Doc.text(v) + aDoc(a) + } Doc.char('[') + Doc.intercalate(Doc.text(", "), params) + Doc.char(']') } @@ -174,4 +188,3 @@ object TypeRef { nel.map { case (s, a) => (TypeRef.TypeVar(s.intern), a) } } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala b/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala index 3a65f9fd2..0c1e3def5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeRefConverter.scala @@ -8,18 +8,20 @@ import org.bykn.bosatsu.Identifier.Constructor import cats.implicits._ object TypeRefConverter { - /** - * given the ability to convert a name to a fully resolved - * type constant, convert TypeRef to Type - */ - def apply[F[_]: Applicative](t: TypeRef)(nameToType: Constructor => F[Type.Const]): F[Type] = { + + /** given the ability to convert a name to a fully resolved type constant, + * convert TypeRef to Type + */ + def apply[F[_]: Applicative]( + t: TypeRef + )(nameToType: Constructor => F[Type.Const]): F[Type] = { def toType(t: TypeRef): F[Type] = apply(t)(nameToType) import Type._ import TypeRef._ t match { - case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v))) + case TypeVar(v) => Applicative[F].pure(TyVar(Type.Var.Bound(v))) case TypeName(n) => nameToType(n.ident).map(TyConst(_)) case TypeArrow(as, b) => (as.traverse(toType(_)), toType(b)).mapN(Fun(_, _)) @@ -27,13 +29,16 @@ object TypeRefConverter { (toType(a), bs.toList.traverse(toType)).mapN(Type.applyAll(_, _)) case TypeForAll(pars, e) => toType(e).map { te => - Type.forAll(pars.map { case (TypeVar(v), optK) => - val k = optK match { - case None => Kind.Type - case Some(k) => k - } - (Type.Var.Bound(v), k) - }.toList, te) + Type.forAll( + pars.map { case (TypeVar(v), optK) => + val k = optK match { + case None => Kind.Type + case Some(k) => k + } + (Type.Var.Bound(v), k) + }.toList, + te + ) } case TypeTuple(ts) => ts.traverse(toType).map(Type.Tuple(_)) @@ -41,10 +46,11 @@ object TypeRefConverter { } def fromTypeA[F[_]: Applicative]( - tpe: Type, - onSkolem: Type.Var.Skolem => F[TypeRef], - onMeta: Long => F[TypeRef], - onConst: Type.Const.Defined => F[TypeRef]): F[TypeRef] = { + tpe: Type, + onSkolem: Type.Var.Skolem => F[TypeRef], + onMeta: Long => F[TypeRef], + onConst: Type.Const.Defined => F[TypeRef] + ): F[TypeRef] = { import Type._ import TypeRef._ @@ -60,17 +66,16 @@ object TypeRefConverter { case Type.Tuple(ts) => // this needs to be above TyConst ts.traverse(loop(_)).map(TypeTuple(_)) - case TyConst(defined@Type.Const.Defined(_, _)) => + case TyConst(defined @ Type.Const.Defined(_, _)) => onConst(defined) case Type.Fun(args, to) => (args.traverse(loop), loop(to)).mapN { (ftr, ttr) => TypeArrow(ftr, ttr) } - case ta@TyApply(_, _) => + case ta @ TyApply(_, _) => val (on, args) = unapplyAll(ta) - (loop(on), args.traverse(loop)).mapN { - (of, arg1) => - TypeApply(of, NonEmptyList.fromListUnsafe(arg1)) + (loop(on), args.traverse(loop)).mapN { (of, arg1) => + TypeApply(of, NonEmptyList.fromListUnsafe(arg1)) } case TyVar(tv) => tv match { @@ -85,7 +90,7 @@ object TypeRefConverter { case other => // the extractors mess this up sys.error(s"unreachable: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } } diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index 97913ece0..5cc308805 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -5,7 +5,7 @@ import cats.arrow.FunctionK import cats.data.{NonEmptyList, Writer} import cats.implicits._ import org.bykn.bosatsu.rankn.Type -import org.typelevel.paiges.{Doc, Document } +import org.typelevel.paiges.{Doc, Document} import scala.collection.immutable.SortedSet import scala.util.hashing.MurmurHash3 @@ -20,13 +20,10 @@ sealed abstract class TypedExpr[+T] { self: Product => MurmurHash3.productHash(this) def tag: T - /** - * For any well typed expression, i.e. - * one that has already gone through type - * inference, we should be able to get a type - * for each expression - * - */ + + /** For any well typed expression, i.e. one that has already gone through type + * inference, we should be able to get a type for each expression + */ lazy val getType: Type = this match { case Generic(params, expr) => @@ -35,9 +32,9 @@ sealed abstract class TypedExpr[+T] { self: Product => tpe case AnnotatedLambda(args, res, _) => Type.Fun(args.map(_._2), res.getType) - case Local(_, tpe, _) => tpe + case Local(_, tpe, _) => tpe case Global(_, _, tpe, _) => tpe - case App(_, _, tpe, _) => tpe + case App(_, _, tpe, _) => tpe case Let(_, _, in, _, _) => in.getType case Literal(_, tpe, _) => @@ -57,30 +54,52 @@ sealed abstract class TypedExpr[+T] { self: Product => case Generic(params, expr) => val pstr = Doc.intercalate( Doc.comma + Doc.lineOrSpace, - params.toList.map { case (p, k) => Doc.text(p.name) + Doc.text(": ") + k.toDoc } + params.toList.map { case (p, k) => + Doc.text(p.name) + Doc.text(": ") + k.toDoc + } ) - (Doc.text("(generic") + Doc.lineOrSpace + Doc.char('[') + pstr + Doc.char(']') + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) + (Doc.text("(generic") + Doc.lineOrSpace + Doc.char('[') + pstr + Doc + .char(']') + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) case Annotation(expr, tpe) => - (Doc.text("(ann") + Doc.lineOrSpace + rept(tpe) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) + (Doc.text("(ann") + Doc.lineOrSpace + rept( + tpe + ) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) case AnnotatedLambda(args, res, _) => (Doc.text("(lambda") + Doc.lineOrSpace + ( - Doc.char('[') + Doc.intercalate(Doc.lineOrSpace, args.toList.map { case (arg, tpe) => + Doc.char('[') + Doc.intercalate( + Doc.lineOrSpace, + args.toList.map { case (arg, tpe) => Doc.text(arg.sourceCodeRepr) + Doc.lineOrSpace + rept(tpe) - }) + Doc.char(']') - ) + Doc.lineOrSpace + loop(res) + Doc.char(')')).nested(4) + } + ) + Doc.char(']') + ) + Doc.lineOrSpace + loop(res) + Doc.char(')')).nested(4) case Local(v, tpe, _) => - (Doc.text("(var") + Doc.lineOrSpace + Doc.text(v.sourceCodeRepr) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(var") + Doc.lineOrSpace + Doc.text( + v.sourceCodeRepr + ) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) case Global(p, v, tpe, _) => val pstr = Doc.text(p.asString + "::" + v.sourceCodeRepr) - (Doc.text("(var") + Doc.lineOrSpace + pstr + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(var") + Doc.lineOrSpace + pstr + Doc.lineOrSpace + rept( + tpe + ) + Doc.char(')')).nested(4) case App(fn, args, tpe, _) => val argsDoc = Doc.intercalate(Doc.lineOrSpace, args.toList.map(loop)) - (Doc.text("(ap") + Doc.lineOrSpace + loop(fn) + Doc.lineOrSpace + argsDoc + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(ap") + Doc.lineOrSpace + loop( + fn + ) + Doc.lineOrSpace + argsDoc + Doc.lineOrSpace + rept(tpe) + Doc + .char(')')).nested(4) case Let(n, b, in, rec, _) => - val nm = if (rec.isRecursive) Doc.text("(letrec") else Doc.text("(let") - (nm + Doc.lineOrSpace + Doc.text(n.sourceCodeRepr) + Doc.lineOrSpace + loop(b) + Doc.lineOrSpace + loop(in) + Doc.char(')')).nested(4) + val nm = + if (rec.isRecursive) Doc.text("(letrec") else Doc.text("(let") + (nm + Doc.lineOrSpace + Doc.text( + n.sourceCodeRepr + ) + Doc.lineOrSpace + loop(b) + Doc.lineOrSpace + loop(in) + Doc.char( + ')' + )).nested(4) case Literal(v, tpe, _) => - (Doc.text("(lit") + Doc.lineOrSpace + Doc.text(v.repr) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + (Doc.text("(lit") + Doc.lineOrSpace + Doc.text( + v.repr + ) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) case Match(arg, branches, _) => implicit val docType: Document[Type] = Document.instance { tpe => rept(tpe) } @@ -88,20 +107,25 @@ sealed abstract class TypedExpr[+T] { self: Product => def pat(p: Pattern[(PackageName, Constructor), Type]): Doc = cpat.document(p) - val bstr = branches.toList.map { case (p, t) => (Doc.char('[') + pat(p) + Doc.comma + Doc.lineOrSpace + loop(t) + Doc.char(']')).nested(4) } - (Doc.text("(match") + Doc.lineOrSpace + loop(arg) + Doc.lineOrSpace + Doc.intercalate(Doc.lineOrSpace, bstr).nested(4) + Doc.char(')')).nested(4) + val bstr = branches.toList.map { case (p, t) => + (Doc.char('[') + pat(p) + Doc.comma + Doc.lineOrSpace + loop( + t + ) + Doc.char(']')).nested(4) + } + (Doc.text("(match") + Doc.lineOrSpace + loop( + arg + ) + Doc.lineOrSpace + Doc + .intercalate(Doc.lineOrSpace, bstr) + .nested(4) + Doc.char(')')).nested(4) } } loop(this).renderTrim(100) } - - /** - * All the free variables in this expression in order - * encountered and with duplicates (to see how often - * they appear) - */ + /** All the free variables in this expression in order encountered and with + * duplicates (to see how often they appear) + */ lazy val freeVarsDup: List[Bindable] = this match { case Generic(_, expr) => @@ -122,8 +146,7 @@ sealed abstract class TypedExpr[+T] { self: Product => val argFree = if (rec.isRecursive) { TypedExpr.filterNot(argFree0)(_ === arg) - } - else argFree0 + } else argFree0 argFree ::: (TypedExpr.filterNot(in.freeVarsDup)(_ === arg)) case Literal(_, _, _) => @@ -139,8 +162,7 @@ sealed abstract class TypedExpr[+T] { self: Product => else TypedExpr.filterNot(bfree)(newBinds) } // we can only take one branch, so count the max on each branch: - val branchFreeMax = branchFrees - .zipWithIndex + val branchFreeMax = branchFrees.zipWithIndex .flatMap { case (names, br) => names.map((_, br)) } // these groupBys are okay because we sort at the end .groupBy(identity) // group-by-name x branch @@ -174,53 +196,86 @@ object TypedExpr { else (h :: t1) // we only allocate here } - type Rho[A] = TypedExpr[A] // an expression with a Rho type (no top level forall) + type Rho[A] = + TypedExpr[A] // an expression with a Rho type (no top level forall) sealed abstract class Name[A] extends TypedExpr[A] with Product - /** - * This says that the resulting term is generic on a given param - * - * The paper says to add TyLam and TyApp nodes, but it never mentions what to do with them - */ - case class Generic[T](typeVars: NonEmptyList[(Type.Var.Bound, Kind)], in: TypedExpr[T]) extends TypedExpr[T] { + + /** This says that the resulting term is generic on a given param + * + * The paper says to add TyLam and TyApp nodes, but it never mentions what to + * do with them + */ + case class Generic[T]( + typeVars: NonEmptyList[(Type.Var.Bound, Kind)], + in: TypedExpr[T] + ) extends TypedExpr[T] { def tag: T = in.tag } // Annotation really means "widen", the term has a type that is a subtype of coerce, so we are widening // to the given type. This happens on Locals/Globals also in their tpe - case class Annotation[T](term: TypedExpr[T], coerce: Type) extends TypedExpr[T] { + case class Annotation[T](term: TypedExpr[T], coerce: Type) + extends TypedExpr[T] { def tag: T = term.tag } - case class AnnotatedLambda[T](args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[T], tag: T) extends TypedExpr[T] + case class AnnotatedLambda[T]( + args: NonEmptyList[(Bindable, Type)], + expr: TypedExpr[T], + tag: T + ) extends TypedExpr[T] case class Local[T](name: Bindable, tpe: Type, tag: T) extends Name[T] - case class Global[T](pack: PackageName, name: Identifier, tpe: Type, tag: T) extends Name[T] - case class App[T](fn: TypedExpr[T], args: NonEmptyList[TypedExpr[T]], result: Type, tag: T) extends TypedExpr[T] - case class Let[T](arg: Bindable, expr: TypedExpr[T], in: TypedExpr[T], recursive: RecursionKind, tag: T) extends TypedExpr[T] + case class Global[T](pack: PackageName, name: Identifier, tpe: Type, tag: T) + extends Name[T] + case class App[T]( + fn: TypedExpr[T], + args: NonEmptyList[TypedExpr[T]], + result: Type, + tag: T + ) extends TypedExpr[T] + case class Let[T]( + arg: Bindable, + expr: TypedExpr[T], + in: TypedExpr[T], + recursive: RecursionKind, + tag: T + ) extends TypedExpr[T] // TODO, this shouldn't have a type, we know the type from Lit currently case class Literal[T](lit: Lit, tpe: Type, tag: T) extends TypedExpr[T] - case class Match[T](arg: TypedExpr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], TypedExpr[T])], tag: T) extends TypedExpr[T] - - def letAllNonRec[T](binds: NonEmptyList[(Bindable, TypedExpr[T])], in: TypedExpr[T], tag: T): Let[T] = { + case class Match[T]( + arg: TypedExpr[T], + branches: NonEmptyList[ + (Pattern[(PackageName, Constructor), Type], TypedExpr[T]) + ], + tag: T + ) extends TypedExpr[T] + + def letAllNonRec[T]( + binds: NonEmptyList[(Bindable, TypedExpr[T])], + in: TypedExpr[T], + tag: T + ): Let[T] = { val in1 = binds.tail match { - case Nil => in + case Nil => in case h1 :: t1 => letAllNonRec(NonEmptyList(h1, t1), in, tag) } val (n, ne) = binds.head Let(n, ne, in1, RecursionKind.NonRecursive, tag) } - /** - * If we expect expr to be a lambda of the given arity, return - * the parameter names and types and the rest of the body - */ - def toArgsBody[A](arity: Int, expr: TypedExpr[A]): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A])] = + /** If we expect expr to be a lambda of the given arity, return the parameter + * names and types and the rest of the body + */ + def toArgsBody[A]( + arity: Int, + expr: TypedExpr[A] + ): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A])] = expr match { - case Generic(_, e) => toArgsBody(arity, e) + case Generic(_, e) => toArgsBody(arity, e) case Annotation(e, _) => toArgsBody(arity, e) case AnnotatedLambda(args, expr, _) => if (args.length == arity) { Some((args, expr)) - } - else { + } else { None } case Let(arg, e, in, r, t) => @@ -232,8 +287,7 @@ object TypedExpr { // can't lift, we could alpha-rename to // deal with this case None - } - else { + } else { // push it down: Some((args, Let(arg, e, body, r, t))) } @@ -247,8 +301,7 @@ object TypedExpr { // can't lift, we could alpha-rename to // deal with this case None - } - else { + } else { Some((n, (p, b1))) } } @@ -257,8 +310,7 @@ object TypedExpr { argSetO.flatMap { argSet => if (argSet.map(_._1).toList.toSet.size == 1) { Some((argSet.head._1, Match(arg, argSet.map(_._2), tag))) - } - else { + } else { None } } @@ -267,19 +319,22 @@ object TypedExpr { implicit class InvariantTypedExpr[A](val self: TypedExpr[A]) extends AnyVal { def allTypes: SortedSet[Type] = - traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1 + traverseType { t => + Writer[SortedSet[Type], Type](SortedSet(t), t) + }.run._1 - /** - * Traverse all the *non-shadowed* types inside the TypedExpr - */ + /** Traverse all the *non-shadowed* types inside the TypedExpr + */ def traverseType[F[_]: Applicative](fn: Type => F[Type]): F[TypedExpr[A]] = self match { - case gen@Generic(params, expr) => + case gen @ Generic(params, expr) => // params shadow below, so they are not free values // and can easily create bugs if passed into fn - val shadowed: Set[Type.Var.Bound] = params.toList.iterator.map(_._1).toSet + val shadowed: Set[Type.Var.Bound] = + params.toList.iterator.map(_._1).toSet val shadowFn: Type => F[Type] = { - case tvar@Type.TyVar(v: Type.Var.Bound) if shadowed(v) => Applicative[F].pure(tvar) + case tvar @ Type.TyVar(v: Type.Var.Bound) if shadowed(v) => + Applicative[F].pure(tvar) case notShadowed => fn(notShadowed) } @@ -288,19 +343,20 @@ object TypedExpr { .map(Generic(params, _)) case Annotation(of, tpe) => (of.traverseType(fn), fn(tpe)).mapN(Annotation(_, _)) - case lam@AnnotatedLambda(args, res, tag) => + case lam @ AnnotatedLambda(args, res, tag) => val a1 = args.traverse { case (n, t) => fn(t).map(n -> _) } fn(lam.getType) *> (a1, res.traverseType(fn)).mapN { - AnnotatedLambda( _, _, tag) + AnnotatedLambda(_, _, tag) } case Local(v, tpe, tag) => fn(tpe).map(Local(v, _, tag)) case Global(p, v, tpe, tag) => fn(tpe).map(Global(p, v, _, tag)) case App(f, args, tpe, tag) => - (f.traverseType(fn), args.traverse(_.traverseType(fn)), fn(tpe)).mapN { - App(_, _, _, tag) - } + (f.traverseType(fn), args.traverse(_.traverseType(fn)), fn(tpe)) + .mapN { + App(_, _, _, tag) + } case Let(v, exp, in, rec, tag) => (exp.traverseType(fn), in.traverseType(fn)).mapN { Let(v, _, _, rec, tag) @@ -309,18 +365,18 @@ object TypedExpr { fn(tpe).map(Literal(lit, _, tag)) case Match(expr, branches, tag) => // all branches have the same type: - val tbranch = branches.traverse { - case (p, t) => - p.traverseType(fn).product(t.traverseType(fn)) + val tbranch = branches.traverse { case (p, t) => + p.traverseType(fn).product(t.traverseType(fn)) } (expr.traverseType(fn), tbranch).mapN(Match(_, _, tag)) } - /** - * This applies fn on all the contained types, replaces the elements, then calls on the - * resulting. This is "bottom up" - */ - def traverseUp[F[_]: Monad](fn: TypedExpr[A] => F[TypedExpr[A]]): F[TypedExpr[A]] = { + /** This applies fn on all the contained types, replaces the elements, then + * calls on the resulting. This is "bottom up" + */ + def traverseUp[F[_]: Monad]( + fn: TypedExpr[A] => F[TypedExpr[A]] + ): F[TypedExpr[A]] = { // be careful not to mistake loop with fn def loop(te: TypedExpr[A]): F[TypedExpr[A]] = te.traverseUp(fn) @@ -337,8 +393,8 @@ object TypedExpr { loop(res).flatMap { res1 => fn(AnnotatedLambda(args, res1, tag)) } - case v@(Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _)) => - fn(v) + case v @ (Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _)) => + fn(v) case App(f, args, tpe, tag) => (loop(f), args.traverse(loop(_))) .mapN(App(_, _, tpe, tag)) @@ -348,8 +404,8 @@ object TypedExpr { .mapN(Let(v, _, _, rec, tag)) .flatMap(fn) case Match(expr, branches, tag) => - val tbranch = branches.traverse { - case (p, t) => loop(t).map((p, _)) + val tbranch = branches.traverse { case (p, t) => + loop(t).map((p, _)) } (loop(expr), tbranch) .mapN(Match(_, _, tag)) @@ -357,52 +413,64 @@ object TypedExpr { } } - /** - * Here are all the global names inside this expression - */ + /** Here are all the global names inside this expression + */ def globals: Set[(PackageName, Identifier)] = traverseUp[Writer[Set[(PackageName, Identifier)], *]] { - case g @ Global(p, i, _, _) => Writer.tell(Set[(PackageName, Identifier)]((p, i))).as(g) + case g @ Global(p, i, _, _) => + Writer.tell(Set[(PackageName, Identifier)]((p, i))).as(g) case notG => Monad[Writer[Set[(PackageName, Identifier)], *]].pure(notG) - } - .written + }.written } - def zonkMeta[F[_]: Applicative, A](te: TypedExpr[A])(fn: Type.Meta => F[Option[Type.Rho]]): F[TypedExpr[A]] = + def zonkMeta[F[_]: Applicative, A](te: TypedExpr[A])( + fn: Type.Meta => F[Option[Type.Rho]] + ): F[TypedExpr[A]] = te.traverseType(Type.zonkMeta(_)(fn)) - /** - * quantify every meta variable that is not escaped into - * the outer environment. - * - * TODO: This can probably be optimized. I think it is currently - * quadradic in depth of the TypedExpr - */ + /** quantify every meta variable that is not escaped into the outer + * environment. + * + * TODO: This can probably be optimized. I think it is currently quadradic in + * depth of the TypedExpr + */ def quantify[F[_]: Monad, A]( - env: Map[(Option[PackageName], Identifier), Type], - rho: TypedExpr.Rho[A], - zFn: Type.Meta => F[Option[Type.Rho]], - writeFn: (Type.Meta, Type.Var) => F[Unit]): F[TypedExpr[A]] = { + env: Map[(Option[PackageName], Identifier), Type], + rho: TypedExpr.Rho[A], + zFn: Type.Meta => F[Option[Type.Rho]], + writeFn: (Type.Meta, Type.Var) => F[Unit] + ): F[TypedExpr[A]] = { // we need to zonk before we get going because // some of the meta-variables may point to the same values def getMetaTyVars(tpes: List[Type]): F[SortedSet[Type.Meta]] = tpes.traverse(Type.zonkMeta(_)(zFn)).map(Type.metaTvs(_)) - def quantify0(forAlls: List[Type.Meta], rho: TypedExpr.Rho[A]): F[TypedExpr[A]] = + def quantify0( + forAlls: List[Type.Meta], + rho: TypedExpr.Rho[A] + ): F[TypedExpr[A]] = NonEmptyList.fromList(forAlls) match { case None => Applicative[F].pure(rho) case Some(metas) => val used: Set[Type.Var.Bound] = Type.tyVarBinders(rho.getType :: Nil) val aligned = Type.alignBinders(metas, used) - val bound = aligned.traverse { case (m, n) => writeFn(m, n).as((n, m.kind)) } + val bound = aligned.traverse { case (m, n) => + writeFn(m, n).as((n, m.kind)) + } // we only need to zonk after doing a write: - (bound, zonkMeta(rho)(zFn)).mapN { (typeArgs, r) => forAll(typeArgs, r) } + (bound, zonkMeta(rho)(zFn)).mapN { (typeArgs, r) => + forAll(typeArgs, r) + } } type Name = (Option[PackageName], Identifier) - def quantifyMetas(env: Map[Name, Type], metas: SortedSet[Type.Meta], te: TypedExpr[A]): F[TypedExpr[A]] = + def quantifyMetas( + env: Map[Name, Type], + metas: SortedSet[Type.Meta], + te: TypedExpr[A] + ): F[TypedExpr[A]] = if (metas.isEmpty) Applicative[F].pure(te) else { for { @@ -436,7 +504,9 @@ object TypedExpr { Annotation(t1, coerce) } case AnnotatedLambda(args, expr, tag) => - val env1 = env ++ args.iterator.map { case (arg, tpe) => ((None, arg)) -> tpe } + val env1 = env ++ args.iterator.map { case (arg, tpe) => + ((None, arg)) -> tpe + } deepQuantify(env1, expr) .map { e1 => lambda(args, e1, tag) @@ -474,13 +544,19 @@ object TypedExpr { * which has a type forall a. Int which is the same * as Int */ - type Branch = (Pattern[(PackageName, Constructor), Type], TypedExpr[A]) + type Branch = + (Pattern[(PackageName, Constructor), Type], TypedExpr[A]) def allTypes[X](p: Pattern[X, Type]): SortedSet[Type] = - p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1 + p.traverseType { t => + Writer[SortedSet[Type], Type](SortedSet(t), t) + }.run + ._1 val allMatchMetas: F[SortedSet[Type.Meta]] = - getMetaTyVars(arg.getType :: branches.foldMap { case (p, _) => allTypes(p) }.toList) + getMetaTyVars(arg.getType :: branches.foldMap { case (p, _) => + allTypes(p) + }.toList) def handleBranch(br: Branch): F[Branch] = { val (p, expr) = br @@ -503,126 +579,154 @@ object TypedExpr { finish(expr).map(forAll(ps, _)) case unreach => // $COVERAGE-OFF$ - sys.error(s"Match quantification yielded neither Generic nor Match: $unreach") - // $COVERAGE-ON$ + sys.error( + s"Match quantification yielded neither Generic nor Match: $unreach" + ) + // $COVERAGE-ON$ } noArg.flatMap(finish) - case nonest@(Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _)) => + case nonest @ (Global(_, _, _, _) | Local(_, _, _) | + Literal(_, _, _)) => Applicative[F].pure(nonest) } deepQuantify(env, rho) } - implicit val traverseTypedExpr: Traverse[TypedExpr] = new Traverse[TypedExpr] { - def traverse[F[_]: Applicative, T, S](typedExprT: TypedExpr[T])(fn: T => F[S]): F[TypedExpr[S]] = - typedExprT match { - case Generic(params, expr) => - expr.traverse(fn).map(Generic(params, _)) - case Annotation(of, tpe) => - of.traverse(fn).map(Annotation(_, tpe)) - case AnnotatedLambda(args, res, tag) => - (res.traverse(fn), fn(tag)).mapN { - AnnotatedLambda(args, _, _) - } - case Local(v, tpe, tag) => - fn(tag).map(Local(v, tpe, _)) - case Global(p, v, tpe, tag) => - fn(tag).map(Global(p, v, tpe, _)) - case App(f, args, tpe, tag) => - (f.traverse(fn), args.traverse(_.traverse(fn)), fn(tag)).mapN { - App(_, _, tpe, _) - } - case Let(v, exp, in, rec, tag) => - (exp.traverse(fn), in.traverse(fn), fn(tag)).mapN { - Let(v, _, _, rec, _) - } - case Literal(lit, tpe, tag) => - fn(tag).map(Literal(lit, tpe, _)) - case Match(expr, branches, tag) => - // all branches have the same type: - val tbranch = branches.traverse { - case (p, t) => + implicit val traverseTypedExpr: Traverse[TypedExpr] = + new Traverse[TypedExpr] { + def traverse[F[_]: Applicative, T, S]( + typedExprT: TypedExpr[T] + )(fn: T => F[S]): F[TypedExpr[S]] = + typedExprT match { + case Generic(params, expr) => + expr.traverse(fn).map(Generic(params, _)) + case Annotation(of, tpe) => + of.traverse(fn).map(Annotation(_, tpe)) + case AnnotatedLambda(args, res, tag) => + (res.traverse(fn), fn(tag)).mapN { + AnnotatedLambda(args, _, _) + } + case Local(v, tpe, tag) => + fn(tag).map(Local(v, tpe, _)) + case Global(p, v, tpe, tag) => + fn(tag).map(Global(p, v, tpe, _)) + case App(f, args, tpe, tag) => + (f.traverse(fn), args.traverse(_.traverse(fn)), fn(tag)).mapN { + App(_, _, tpe, _) + } + case Let(v, exp, in, rec, tag) => + (exp.traverse(fn), in.traverse(fn), fn(tag)).mapN { + Let(v, _, _, rec, _) + } + case Literal(lit, tpe, tag) => + fn(tag).map(Literal(lit, tpe, _)) + case Match(expr, branches, tag) => + // all branches have the same type: + val tbranch = branches.traverse { case (p, t) => t.traverse(fn).map((p, _)) - } - (expr.traverse(fn), tbranch, fn(tag)).mapN(Match(_, _, _)) - } + } + (expr.traverse(fn), tbranch, fn(tag)).mapN(Match(_, _, _)) + } - def foldLeft[A, B](typedExprA: TypedExpr[A], b: B)(f: (B, A) => B): B = typedExprA match { - case Generic(_, e) => - foldLeft(e, b)(f) - case Annotation(e, _) => - foldLeft(e, b)(f) - case AnnotatedLambda(_, e, tag) => - val b1 = foldLeft(e, b)(f) - f(b1, tag) - case n: Name[A] => f(b, n.tag) - case App(fn, args, _, tag) => - val b1 = foldLeft(fn, b)(f) - val b2 = args.foldLeft(b1)((b1, a) => foldLeft(a, b1)(f)) - f(b2, tag) - case Let(_, exp, in, _, tag) => - val b1 = foldLeft(exp, b)(f) - val b2 = foldLeft(in, b1)(f) - f(b2, tag) - case Literal(_, _, tag) => - f(b, tag) - case Match(arg, branches, tag) => - val b1 = foldLeft(arg, b)(f) - val b2 = branches.foldLeft(b1) { case (bn, (_, t)) => foldLeft(t, bn)(f) } - f(b2, tag) - } + def foldLeft[A, B](typedExprA: TypedExpr[A], b: B)(f: (B, A) => B): B = + typedExprA match { + case Generic(_, e) => + foldLeft(e, b)(f) + case Annotation(e, _) => + foldLeft(e, b)(f) + case AnnotatedLambda(_, e, tag) => + val b1 = foldLeft(e, b)(f) + f(b1, tag) + case n: Name[A] => f(b, n.tag) + case App(fn, args, _, tag) => + val b1 = foldLeft(fn, b)(f) + val b2 = args.foldLeft(b1)((b1, a) => foldLeft(a, b1)(f)) + f(b2, tag) + case Let(_, exp, in, _, tag) => + val b1 = foldLeft(exp, b)(f) + val b2 = foldLeft(in, b1)(f) + f(b2, tag) + case Literal(_, _, tag) => + f(b, tag) + case Match(arg, branches, tag) => + val b1 = foldLeft(arg, b)(f) + val b2 = branches.foldLeft(b1) { case (bn, (_, t)) => + foldLeft(t, bn)(f) + } + f(b2, tag) + } - def foldRight[A, B](typedExprA: TypedExpr[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = typedExprA match { - case Generic(_, e) => - foldRight(e, lb)(f) - case Annotation(e, _) => - foldRight(e, lb)(f) - case AnnotatedLambda(_, e, tag) => - val lb1 = f(tag, lb) - foldRight(e, lb1)(f) - case n: Name[A] => f(n.tag, lb) - case App(fn, args, _, tag) => - val b1 = f(tag, lb) - val b2 = args.toList.foldRight(b1)((a, b1) => foldRight(a, b1)(f)) - foldRight(fn, b2)(f) - case Let(_, exp, in, _, tag) => - val b1 = f(tag, lb) - val b2 = foldRight(in, b1)(f) - foldRight(exp, b2)(f) - case Literal(_, _, tag) => - f(tag, lb) - case Match(arg, branches, tag) => - val b1 = f(tag, lb) - val b2 = branches.foldRight(b1) { case ((_, t), bn) => foldRight(t, bn)(f) } - foldRight(arg, b2)(f) - } + def foldRight[A, B](typedExprA: TypedExpr[A], lb: Eval[B])( + f: (A, Eval[B]) => Eval[B] + ): Eval[B] = typedExprA match { + case Generic(_, e) => + foldRight(e, lb)(f) + case Annotation(e, _) => + foldRight(e, lb)(f) + case AnnotatedLambda(_, e, tag) => + val lb1 = f(tag, lb) + foldRight(e, lb1)(f) + case n: Name[A] => f(n.tag, lb) + case App(fn, args, _, tag) => + val b1 = f(tag, lb) + val b2 = args.toList.foldRight(b1)((a, b1) => foldRight(a, b1)(f)) + foldRight(fn, b2)(f) + case Let(_, exp, in, _, tag) => + val b1 = f(tag, lb) + val b2 = foldRight(in, b1)(f) + foldRight(exp, b2)(f) + case Literal(_, _, tag) => + f(tag, lb) + case Match(arg, branches, tag) => + val b1 = f(tag, lb) + val b2 = branches.foldRight(b1) { case ((_, t), bn) => + foldRight(t, bn)(f) + } + foldRight(arg, b2)(f) + } - override def map[A, B](te: TypedExpr[A])(fn: A => B): TypedExpr[B] = te match { - case Generic(tv, in) => Generic(tv, map(in)(fn)) - case Annotation(term, tpe) => Annotation(map(term)(fn), tpe) - case AnnotatedLambda(args, expr, tag) => AnnotatedLambda(args, map(expr)(fn), fn(tag)) - case l@Local(_, _, _) => l.copy(tag = fn(l.tag)) - case g@Global(_, _, _, _) => g.copy(tag = fn(g.tag)) - case App(fnT, args, tpe, tag) => App(map(fnT)(fn), args.map(map(_)(fn)), tpe, fn(tag)) - case Let(b, e, in, r, t) => Let(b, map(e)(fn), map(in)(fn), r, fn(t)) - case lit@Literal(_, _, _) => lit.copy(tag = fn(lit.tag)) - case Match(arg, branches, tag) => - Match(map(arg)(fn), branches.map { case (p, t) => (p, map(t)(fn)) }, fn(tag)) + override def map[A, B](te: TypedExpr[A])(fn: A => B): TypedExpr[B] = + te match { + case Generic(tv, in) => Generic(tv, map(in)(fn)) + case Annotation(term, tpe) => Annotation(map(term)(fn), tpe) + case AnnotatedLambda(args, expr, tag) => + AnnotatedLambda(args, map(expr)(fn), fn(tag)) + case l @ Local(_, _, _) => l.copy(tag = fn(l.tag)) + case g @ Global(_, _, _, _) => g.copy(tag = fn(g.tag)) + case App(fnT, args, tpe, tag) => + App(map(fnT)(fn), args.map(map(_)(fn)), tpe, fn(tag)) + case Let(b, e, in, r, t) => Let(b, map(e)(fn), map(in)(fn), r, fn(t)) + case lit @ Literal(_, _, _) => lit.copy(tag = fn(lit.tag)) + case Match(arg, branches, tag) => + Match( + map(arg)(fn), + branches.map { case (p, t) => (p, map(t)(fn)) }, + fn(tag) + ) + } } - } type Coerce = FunctionK[TypedExpr, TypedExpr] // We know initTpe <:< instTpe, we may be able to simply // fix some of the universally quantified variables - private def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): Option[TypedExpr[A]] = + private def instantiateTo[A]( + gen: Generic[A], + instTpe: Type.Rho, + kinds: Type => Option[Kind] + ): Option[TypedExpr[A]] = gen.getType match { case Type.ForAll(bs, in) => import Type._ - def solve(left: Type, right: Type, state: Map[Type.Var, Type], solveSet: Set[Type.Var]): Option[Map[Type.Var, Type]] = + def solve( + left: Type, + right: Type, + state: Map[Type.Var, Type], + solveSet: Set[Type.Var] + ): Option[Map[Type.Var, Type]] = (left, right) match { case (TyVar(v), right) if solveSet(v) => Some(state.updated(v, right)) @@ -643,8 +747,7 @@ object TypedExpr { if (left == right) { // can't recurse further into left Some(state) - } - else None + } else None case (TyApply(_, _), _) => None } @@ -658,7 +761,9 @@ object TypedExpr { } private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] = - p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1 + p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) } + .run + ._1 private def pushGeneric[A](g: Generic[A]): Option[TypedExpr[A]] = g.in match { @@ -666,23 +771,23 @@ object TypedExpr { val argFree = Type.freeBoundTyVars(args.toList.map(_._2)).toSet if (g.typeVars.exists { case (b, _) => argFree(b) }) { None - } - else { + } else { val gbody = Generic(g.typeVars, body) val pushedBody = pushGeneric(gbody).getOrElse(gbody) Some(AnnotatedLambda(args, pushedBody, a)) } // we can do the same thing on Match case Match(arg, branches, tag) => - val preTypes = arg.allTypes | branches.foldLeft(arg.allTypes) { case (ts, (p, _)) => ts | allPatternTypes(p) } + val preTypes = arg.allTypes | branches.foldLeft(arg.allTypes) { + case (ts, (p, _)) => ts | allPatternTypes(p) + } val argFree = Type.freeBoundTyVars(preTypes.toList).toSet if (g.typeVars.exists { case (b, _) => argFree(b) }) { None - } - else { + } else { // the only the branches have generics val b1 = branches.map { case (p, b) => - val gb = Generic(g.typeVars, b) + val gb = Generic(g.typeVars, b) val gb1 = pushGeneric(gb).getOrElse(gb) (p, gb1) } @@ -692,8 +797,7 @@ object TypedExpr { val argFree = Type.freeBoundTyVars(v.getType :: Nil).toSet if (g.typeVars.exists { case (b, _) => argFree(b) }) { None - } - else { + } else { val gin = Generic(g.typeVars, in) val gin1 = pushGeneric(gin).getOrElse(gin) Some(Let(b, v, gin1, rec, tag)) @@ -710,7 +814,7 @@ object TypedExpr { val cb = coerceRho(b, kinds) val cas = args.map { case aRho: Type.Rho => Some(coerceRho(aRho, kinds)) - case _ => None + case _ => None } coerceFn1(args, b, cas, cb, kinds) @@ -719,19 +823,20 @@ object TypedExpr { def apply[A](expr: TypedExpr[A]) = expr match { case _ if expr.getType.sameAs(tpe) => expr - case Annotation(t, _) => self(t) - case Local(_, _, _) | Global(_, _, _, _) | AnnotatedLambda(_, _, _)| Literal(_, _, _) => + case Annotation(t, _) => self(t) + case Local(_, _, _) | Global(_, _, _, _) | + AnnotatedLambda(_, _, _) | Literal(_, _, _) => // All of these are widened. The lambda seems like we should be able to do // better, but the type isn't a Fun(Type, Type.Rho)... this is probably unreachable for // the AnnotatedLambda Annotation(expr, tpe) - case gen@Generic(_, _) => + case gen @ Generic(_, _) => pushGeneric(gen) match { case Some(e1) => self(e1) case None => instantiateTo(gen, tpe, kinds) match { case Some(res) => res - case None => + case None => // TODO: this is basically giving up Annotation(gen, tpe) } @@ -739,7 +844,7 @@ object TypedExpr { case App(fn, aargs, _, tag) => fn match { case AnnotatedLambda(lamArgs, body, _) => - //(\xs - res)(ys) == let x1 = y1 in let x2 = y2 in ... res + // (\xs - res)(ys) == let x1 = y1 in let x2 = y2 in ... res val binds = lamArgs.zip(aargs).map { case ((n, rho: Type.Rho), arg) => (n, coerceRho(rho, kinds)(arg)) @@ -756,7 +861,13 @@ object TypedExpr { case (arg, nonRho) => (arg, nonRho, None) } - val fn1 = coerceFn1(cArgs.map(_._2), tpe, cArgs.map(_._3), self, kinds)(fn) + val fn1 = coerceFn1( + cArgs.map(_._2), + tpe, + cArgs.map(_._3), + self, + kinds + )(fn) App(fn1, cArgs.map(_._1), tpe, tag) case _ => // TODO, what should we do here? @@ -774,14 +885,17 @@ object TypedExpr { // TODO: this may be wrong. e.g. we could leaving meta in the types // embedded in patterns, this does not seem to happen since we would // error if metas escape typechecking - Match(arg, branches.map { case (p, expr) => (p, self(expr)) }, tag) + Match( + arg, + branches.map { case (p, expr) => (p, self(expr)) }, + tag + ) } } } - /** - * Return the list of the free vars - */ + /** Return the list of the free vars + */ def freeVars[A](ts: List[TypedExpr[A]]): List[Bindable] = freeVarsDup(ts).distinct @@ -791,16 +905,18 @@ object TypedExpr { private def freeVarsDup[A](ts: List[TypedExpr[A]]): List[Bindable] = ts.flatMap(_.freeVarsDup) - /** - * Try to substitute ex for ident in the expression: in - * - * This can fail if the free variables in ex are shadowed - * above ident in in. - * - * this code is very similar to Declaration.substitute - * if bugs are found in one, consult the other - */ - def substitute[A](ident: Bindable, ex: TypedExpr[A], in: TypedExpr[A]): Option[TypedExpr[A]] = { + /** Try to substitute ex for ident in the expression: in + * + * This can fail if the free variables in ex are shadowed above ident in in. + * + * this code is very similar to Declaration.substitute if bugs are found in + * one, consult the other + */ + def substitute[A]( + ident: Bindable, + ex: TypedExpr[A], + in: TypedExpr[A] + ): Option[TypedExpr[A]] = { // if we hit a shadow, we don't need to substitute down // that branch @inline def shadows(i: Bindable): Boolean = i === ident @@ -812,7 +928,7 @@ object TypedExpr { def loop(in: TypedExpr[A]): Option[TypedExpr[A]] = in match { - case Local(i, _, _) if i === ident => Some(ex) + case Local(i, _, _) if i === ident => Some(ex) case Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _) => Some(in) case Generic(a, expr) => loop(expr).map(Generic(a, _)) @@ -824,20 +940,19 @@ object TypedExpr { else loop(res).map(AnnotatedLambda(args, _, tag)) case App(fn, args, tpe, tag) => (loop(fn), args.traverse(loop(_))).mapN(App(_, _, tpe, tag)) - case let@Let(arg, argE, in, rec, tag) => + case let @ Let(arg, argE, in, rec, tag) => if (masks(arg)) None else if (shadows(arg)) { // recursive shadow blocks both argE and in if (rec.isRecursive) Some(let) else loop(argE).map(Let(arg, _, in, rec, tag)) - } - else { + } else { (loop(argE), loop(in)).mapN(Let(arg, _, _, rec, tag)) } case Match(arg, branches, tag) => // Maintain the order we encounter things: val arg1 = loop(arg) - val b1 = branches.traverse { case in@(p, b) => + val b1 = branches.traverse { case in @ (p, b) => // these are not free variables in this branch val ns = p.names if (ns.exists(masks)) None @@ -850,7 +965,10 @@ object TypedExpr { loop(in) } - def substituteTypeVar[A](typedExpr: TypedExpr[A], env: Map[Type.Var, Type]): TypedExpr[A] = + def substituteTypeVar[A]( + typedExpr: TypedExpr[A], + env: Map[Type.Var, Type] + ): TypedExpr[A] = typedExpr match { case Generic(params, expr) => // we need to remove the params which are shadowed below @@ -858,16 +976,15 @@ object TypedExpr { val env1 = env.iterator.filter { case (k, _) => !paramSet(k) }.toMap Generic(params, substituteTypeVar(expr, env1)) case Annotation(of, tpe) => - Annotation( - substituteTypeVar(of, env), - Type.substituteVar(tpe, env)) + Annotation(substituteTypeVar(of, env), Type.substituteVar(tpe, env)) case AnnotatedLambda(args, res, tag) => AnnotatedLambda( - args.map { case (n, tpe) => + args.map { case (n, tpe) => (n, Type.substituteVar(tpe, env)) }, substituteTypeVar(res, env), - tag) + tag + ) case Local(v, tpe, tag) => Local(v, Type.substituteVar(tpe, env), tag) case Global(p, v, tpe, tag) => @@ -877,45 +994,49 @@ object TypedExpr { substituteTypeVar(f, env), args.map(substituteTypeVar(_, env)), Type.substituteVar(tpe, env), - tag) + tag + ) case Let(v, exp, in, rec, tag) => Let( v, substituteTypeVar(exp, env), substituteTypeVar(in, env), rec, - tag) + tag + ) case Literal(lit, tpe, tag) => Literal(lit, Type.substituteVar(tpe, env), tag) case Match(expr, branches, tag) => - val branches1 = branches.map { - case (p, t) => - val p1 = p.mapType(Type.substituteVar(_, env)) - val t1 = substituteTypeVar(t, env) - (p1, t1) + val branches1 = branches.map { case (p, t) => + val p1 = p.mapType(Type.substituteVar(_, env)) + val t1 = substituteTypeVar(t, env) + (p1, t1) } val expr1 = substituteTypeVar(expr, env) Match(expr1, branches1, tag) } - private def replaceVarType[A](te: TypedExpr[A], name: Bindable, tpe: Type): TypedExpr[A] = { + private def replaceVarType[A]( + te: TypedExpr[A], + name: Bindable, + tpe: Type + ): TypedExpr[A] = { def recur(t: TypedExpr[A]) = replaceVarType(t, name, tpe) te match { - case Generic(tv, in) => Generic(tv, recur(in)) - case Annotation(term, tpe) => Annotation(recur(term), tpe) + case Generic(tv, in) => Generic(tv, recur(in)) + case Annotation(term, tpe) => Annotation(recur(term), tpe) case AnnotatedLambda(args, expr, tag) => // this is a kind of let: if (args.exists(_._1 == name)) { // we are shadowing, so we are done: te - } - else { + } else { // no shadow AnnotatedLambda(args, recur(expr), tag) } case Local(nm, _, tag) if nm == name => Local(name, tpe, tag) - case n: Name[A] => n + case n: Name[A] => n case App(fnT, args, tpe, tag) => App(recur(fnT), args.map(recur), tpe, tag) case Let(b, e, in, r, t) => @@ -929,51 +1050,59 @@ object TypedExpr { // but b does shadow inside `in` Let(b, recur(e), in, r, t) } - } - else Let(b, recur(e), recur(in), r, t) - case lit@Literal(_, _, _) => lit + } else Let(b, recur(e), recur(in), r, t) + case lit @ Literal(_, _, _) => lit case Match(arg, branches, tag) => Match(recur(arg), branches.map { case (p, t) => (p, recur(t)) }, tag) } } - /** - * TODO this seems pretty expensive to blindly apply: we are deoptimizing - * the nodes pretty heavily - */ - def coerceFn(args: NonEmptyList[Type], result: Type.Rho, coarg: NonEmptyList[Coerce], cores: Coerce, kinds: Type => Option[Kind]): Coerce = + /** TODO this seems pretty expensive to blindly apply: we are deoptimizing the + * nodes pretty heavily + */ + def coerceFn( + args: NonEmptyList[Type], + result: Type.Rho, + coarg: NonEmptyList[Coerce], + cores: Coerce, + kinds: Type => Option[Kind] + ): Coerce = coerceFn1(args, result, coarg.map(Some(_)), cores, kinds) - private def coerceFn1(arg: NonEmptyList[Type], result: Type.Rho, coargOpt: NonEmptyList[Option[Coerce]], cores: Coerce, kinds: Type => Option[Kind]): Coerce = + private def coerceFn1( + arg: NonEmptyList[Type], + result: Type.Rho, + coargOpt: NonEmptyList[Option[Coerce]], + cores: Coerce, + kinds: Type => Option[Kind] + ): Coerce = new FunctionK[TypedExpr, TypedExpr] { self => val fntpe = Type.Fun(arg, result) def apply[A](expr: TypedExpr[A]) = { expr match { - case _ if expr.getType.sameAs(fntpe) => expr - case Annotation(t, _) => self(t) + case _ if expr.getType.sameAs(fntpe) => expr + case Annotation(t, _) => self(t) case AnnotatedLambda(args0, res, tag) => // note, Var(None, name, originalType, tag) // is hanging out in res, or it is unused - val args1 = args0.zip(arg).map { - case ((n, _), t) => (n, t) + val args1 = args0.zip(arg).map { case ((n, _), t) => + (n, t) } - val res1 = args1 - .toList - .foldRight(res) { - case ((name, arg), res) => - replaceVarType(res, name, arg) + val res1 = args1.toList + .foldRight(res) { case ((name, arg), res) => + replaceVarType(res, name, arg) } AnnotatedLambda(args1, cores(res1), tag) - case gen@Generic(_, _) => + case gen @ Generic(_, _) => pushGeneric(gen) match { case Some(e1) => self(e1) case None => instantiateTo(gen, fntpe, kinds) match { case Some(res) => res - case None => Annotation(gen, fntpe) + case None => Annotation(gen, fntpe) } - } + } case Local(_, _, _) | Global(_, _, _, _) | Literal(_, _, _) => Annotation(expr, fntpe) case Let(arg, argE, in, rec, tag) => @@ -984,21 +1113,23 @@ object TypedExpr { // error if metas escape typechecking Match(arg, branches.map { case (p, expr) => (p, self(expr)) }, tag) case App(AnnotatedLambda(lamArgs, body, _), aArgs, _, tag) => - //(\x - res)(y) == let x = y in res + // (\x - res)(y) == let x = y in res val arg1 = lamArgs.zip(aArgs).map { case ((n, rho: Type.Rho), arg) => (n, coerceRho(rho, kinds)(arg)) - case ((n, _), arg) => (n, arg) + case ((n, _), arg) => (n, arg) } letAllNonRec(arg1, self(body), tag) case App(_, _, _, _) => /* - * We have to be careful not to collide with the free vars in expr - * TODO: it is unclear why we are doing this... it may have just been - * a cute trick in the original rankn types paper, but I'm not - * sure what is buying us. - */ + * We have to be careful not to collide with the free vars in expr + * TODO: it is unclear why we are doing this... it may have just been + * a cute trick in the original rankn types paper, but I'm not + * sure what is buying us. + */ val free = freeVarsSet(expr :: Nil) - val nameGen = Type.allBinders.iterator.map { v => Identifier.Name(v.name) }.filterNot(free) + val nameGen = Type.allBinders.iterator + .map { v => Identifier.Name(v.name) } + .filterNot(free) val lamArgs = arg.map { t => (nameGen.next(), t) } val aArgs = lamArgs.map { case (n, t) => Local(n, t, expr.tag) } // name -> (expr((name: arg)): result) @@ -1008,7 +1139,10 @@ object TypedExpr { } } - def forAll[A](params: NonEmptyList[(Type.Var.Bound, Kind)], expr: TypedExpr[A]): TypedExpr[A] = + def forAll[A]( + params: NonEmptyList[(Type.Var.Bound, Kind)], + expr: TypedExpr[A] + ): TypedExpr[A] = expr match { case Generic(ps, ex0) => // if params and ps have duplicates, that @@ -1018,7 +1152,7 @@ object TypedExpr { val innerSet = ps.toList.iterator.map(_._1).toSet val newParams = params.toList.filterNot { case (v, _) => innerSet(v) } val ps1 = NonEmptyList.fromList(newParams) match { - case None => ps + case None => ps case Some(nep) => nep ::: ps } forAll(ps1, ex0) @@ -1028,11 +1162,15 @@ object TypedExpr { // we not uncommonly add an annotation just to make a generic wrapper to get back where // we were case Annotation(term, _) if g.getType.sameAs(term.getType) => term - case _ => g + case _ => g } } - private def lambda[A](args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[A], tag: A): TypedExpr[A] = + private def lambda[A]( + args: NonEmptyList[(Bindable, Type)], + expr: TypedExpr[A], + tag: A + ): TypedExpr[A] = expr match { // TODO: this branch is never exercised. There is probably some reason for that // that the types/invariants are losing @@ -1044,13 +1182,11 @@ object TypedExpr { val collisions = frees.intersect(quants) if (collisions.isEmpty) { Generic(ps, AnnotatedLambda(args, ex0, tag)) - } - else { - // don't replace with any existing type variable or any of the free variables + } else { + // don't replace with any existing type variable or any of the free variables val replacements = Type.allBinders.iterator.filterNot(quants | frees) val repMap: Map[Type.Var.Bound, Type.Var.Bound] = - collisions - .iterator + collisions.iterator .zip(replacements) .toMap @@ -1058,8 +1194,7 @@ object TypedExpr { val typeMap: Map[Type.Var, Type] = repMap.iterator.map { case (k, v) => (k, Type.TyVar(v)) }.toMap val ex1 = substituteTypeVar(ex0, typeMap) - Generic(ps1, - AnnotatedLambda(args, ex1, tag)) + Generic(ps1, AnnotatedLambda(args, ex1, tag)) } case notGen => AnnotatedLambda(args, notGen, tag) diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index de199cd3a..5a0d8fba2 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -9,38 +9,69 @@ import Identifier.{Bindable, Constructor} object TypedExprNormalization { import TypedExpr._ - type ScopeT[A, S] = Map[(Option[PackageName], Bindable), (RecursionKind, TypedExpr[A], S)] + type ScopeT[A, S] = + Map[(Option[PackageName], Bindable), (RecursionKind, TypedExpr[A], S)] type Scope[A] = FixType.Fix[ScopeT[A, *]] def emptyScope[A]: Scope[A] = FixType.fix[ScopeT[A, *]](Map.empty) implicit final class ScopeOps[A](private val scope: Scope[A]) extends AnyVal { - def updated(key: Bindable, value: (RecursionKind, TypedExpr[A], Scope[A])): Scope[A] = - FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope).updated((None, key), value)) - - def updatedGlobal(pack: PackageName, key: Bindable, value: (RecursionKind, TypedExpr[A], Scope[A])): Scope[A] = - FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope).updated((Some(pack), key), value)) + def updated( + key: Bindable, + value: (RecursionKind, TypedExpr[A], Scope[A]) + ): Scope[A] = + FixType.fix[ScopeT[A, *]]( + FixType.unfix[ScopeT[A, *]](scope).updated((None, key), value) + ) + + def updatedGlobal( + pack: PackageName, + key: Bindable, + value: (RecursionKind, TypedExpr[A], Scope[A]) + ): Scope[A] = + FixType.fix[ScopeT[A, *]]( + FixType.unfix[ScopeT[A, *]](scope).updated((Some(pack), key), value) + ) def -(key: Bindable): Scope[A] = - FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope) - (None -> key)) + FixType.fix[ScopeT[A, *]]( + FixType.unfix[ScopeT[A, *]](scope) - (None -> key) + ) - def getLocal(key: Bindable): Option[(RecursionKind, TypedExpr[A], Scope[A])] = + def getLocal( + key: Bindable + ): Option[(RecursionKind, TypedExpr[A], Scope[A])] = FixType.unfix[ScopeT[A, *]](scope).get((None, key)) - def getGlobal(pack: PackageName, n: Bindable): Option[(RecursionKind, TypedExpr[A], Scope[A])] = + def getGlobal( + pack: PackageName, + n: Bindable + ): Option[(RecursionKind, TypedExpr[A], Scope[A])] = FixType.unfix[ScopeT[A, *]](scope).get((Some(pack), n)) } - private def nameScope[A](b: Bindable, r: RecursionKind, scope: Scope[A]): (Option[Bindable], Scope[A]) = + private def nameScope[A]( + b: Bindable, + r: RecursionKind, + scope: Scope[A] + ): (Option[Bindable], Scope[A]) = if (r.isRecursive) (Some(b), scope - b) else (None, scope) - def normalizeAll[A, V](pack: PackageName, lets: List[(Bindable, RecursionKind, TypedExpr[A])], typeEnv: TypeEnv[V]): List[(Bindable, RecursionKind, TypedExpr[A])] = { + def normalizeAll[A, V]( + pack: PackageName, + lets: List[(Bindable, RecursionKind, TypedExpr[A])], + typeEnv: TypeEnv[V] + ): List[(Bindable, RecursionKind, TypedExpr[A])] = { @annotation.tailrec - def loop(scope: Scope[A], lets: List[(Bindable, RecursionKind, TypedExpr[A])], acc: List[(Bindable, RecursionKind, TypedExpr[A])]): List[(Bindable, RecursionKind, TypedExpr[A])] = + def loop( + scope: Scope[A], + lets: List[(Bindable, RecursionKind, TypedExpr[A])], + acc: List[(Bindable, RecursionKind, TypedExpr[A])] + ): List[(Bindable, RecursionKind, TypedExpr[A])] = lets match { - case Nil => acc.reverse + case Nil => acc.reverse case (b, r, t) :: tail => // if we have a recursive value it shadows the scope val (optName, s0) = nameScope(b, r, scope) @@ -53,33 +84,44 @@ object TypedExprNormalization { } def normalizeProgram[A, V]( - p: PackageName, - fullTypeEnv: TypeEnv[V], - prog: Program[TypeEnv[V], TypedExpr[Declaration], A]): Program[TypeEnv[V], TypedExpr[Declaration], A] = { - val Program(typeEnv, lets, extDefs, stmts) = prog - val normalLets = normalizeAll(p, lets, fullTypeEnv) - Program(typeEnv, normalLets, extDefs, stmts) - } + p: PackageName, + fullTypeEnv: TypeEnv[V], + prog: Program[TypeEnv[V], TypedExpr[Declaration], A] + ): Program[TypeEnv[V], TypedExpr[Declaration], A] = { + val Program(typeEnv, lets, extDefs, stmts) = prog + val normalLets = normalizeAll(p, lets, fullTypeEnv) + Program(typeEnv, normalLets, extDefs, stmts) + } // if you have made one step of progress, use this to recurse // so we don't throw away if we don't progress more - private def normalize1[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V]): Some[TypedExpr[A]] = + private def normalize1[A, V]( + namerec: Option[Bindable], + te: TypedExpr[A], + scope: Scope[A], + typeEnv: TypeEnv[V] + ): Some[TypedExpr[A]] = normalizeLetOpt(namerec, te, scope, typeEnv) match { - case None => Some(te) - case s@Some(_) => s + case None => Some(te) + case s @ Some(_) => s } private def setType[A](expr: TypedExpr[A], tpe: Type): TypedExpr[A] = if (!tpe.sameAs(expr.getType)) Annotation(expr, tpe) else expr - /** - * if the te is not in normal form, transform it into normal form - */ - def normalizeLetOpt[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V]): Option[TypedExpr[A]] = + /** if the te is not in normal form, transform it into normal form + */ + def normalizeLetOpt[A, V]( + namerec: Option[Bindable], + te: TypedExpr[A], + scope: Scope[A], + typeEnv: TypeEnv[V] + ): Option[TypedExpr[A]] = te match { - case g@Generic(_, Annotation(term, _)) if g.getType.sameAs(term.getType) => + case g @ Generic(_, Annotation(term, _)) + if g.getType.sameAs(term.getType) => normalize1(namerec, term, scope, typeEnv) - case g@Generic(vars, in) => + case g @ Generic(vars, in) => // normalize the inside, then get all the freeBoundTyVars and // and if we can reallocate typevars to be the a, b, ... do so, // if they are the same, return none @@ -99,10 +141,11 @@ object TypedExprNormalization { case None => if (freeVars == vars.toList) None else Some(Generic(nonEmpty, in)) - case Some(gen@Generic(_, _)) => + case Some(gen @ Generic(_, _)) => // in1 could be a generic in a Some(forAll(nonEmpty, gen)) - case Some(Annotation(term, _)) if g.getType.sameAs(term.getType) => + case Some(Annotation(term, _)) + if g.getType.sameAs(term.getType) => Some(term) case Some(notGen) => Some(Generic(nonEmpty, notGen)) @@ -121,58 +164,69 @@ object TypedExprNormalization { if (notSameTpe eq term) { if (nt == tpe) None else Some(Annotation(term, nt)) - } - else Some(Annotation(notSameTpe, nt)) + } else Some(Annotation(notSameTpe, nt)) } case AnnotatedLambda(lamArgs0, expr, tag) => - val lamArgs = lamArgs0.map { case (n, tpe0) => n -> Type.normalize(tpe0) } + val lamArgs = lamArgs0.map { case (n, tpe0) => + n -> Type.normalize(tpe0) + } def doesntUseArgs(te: TypedExpr[A]): Boolean = lamArgs.forall { case (n, _) => te.notFree(n) } // assuming b is bound below lamArgs, return true if it doesn't shadow an arg def doesntShadow(b: Bindable): Boolean = !lamArgs.exists { case (n, _) => n === b } - + def matchesArgs(nel: NonEmptyList[TypedExpr[A]]): Boolean = - (nel.length == lamArgs.length) && lamArgs.iterator.zip(nel.iterator).forall { - case ((lamN, _), Local(argN, _, _)) => lamN === argN - case _ => false - } + (nel.length == lamArgs.length) && lamArgs.iterator + .zip(nel.iterator) + .forall { + case ((lamN, _), Local(argN, _, _)) => lamN === argN + case _ => false + } // we can normalize the arg to the smallest non-free var // x -> f(x) == f (eta conversion) // x -> generic(g) = generic(x -> g) if the type of x doesn't have free types with vars val e1 = normalize1(None, expr, scope, typeEnv).get e1 match { - case App(fn, aargs, _, _) if matchesArgs(aargs) && doesntUseArgs(fn) => + case App(fn, aargs, _, _) + if matchesArgs(aargs) && doesntUseArgs(fn) => normalize1(None, setType(fn, te.getType), scope, typeEnv) - case Let(arg1, ex, in, rec, tag1) if doesntUseArgs(ex) && doesntShadow(arg1) => + case Let(arg1, ex, in, rec, tag1) + if doesntUseArgs(ex) && doesntShadow(arg1) => // x -> // y = z // f(y) - //same as: - //y = z - //x -> f(y) - //avoid recomputing y - //TODO: we could reorder Lets if we have several in a row - normalize1(None, Let(arg1, ex, AnnotatedLambda(lamArgs, in, tag), rec, tag1), scope, typeEnv) - case m@Match(arg1, branches, tag1) if lamArgs.forall { case (arg, _) => arg1.notFree(arg) } => + // same as: + // y = z + // x -> f(y) + // avoid recomputing y + // TODO: we could reorder Lets if we have several in a row + normalize1( + None, + Let(arg1, ex, AnnotatedLambda(lamArgs, in, tag), rec, tag1), + scope, + typeEnv + ) + case m @ Match(arg1, branches, tag1) if lamArgs.forall { + case (arg, _) => arg1.notFree(arg) + } => // same as above: if match does not depend on lambda arg, lift it out - val b1 = branches.traverse { case (p, b) => - if (!lamArgs.exists { case (arg, _) => p.names.contains(arg) }) { - Some((p, AnnotatedLambda(lamArgs, b, tag))) - } - else None - } - b1 match { - case None => - if ((m eq expr) && (lamArgs === lamArgs0)) None - else Some(AnnotatedLambda(lamArgs, m, tag)) - case Some(bs) => - val m1 = Match(arg1, bs, tag1) - normalize1(namerec, m1, scope, typeEnv) - } + val b1 = branches.traverse { case (p, b) => + if (!lamArgs.exists { case (arg, _) => p.names.contains(arg) }) { + Some((p, AnnotatedLambda(lamArgs, b, tag))) + } else None + } + b1 match { + case None => + if ((m eq expr) && (lamArgs === lamArgs0)) None + else Some(AnnotatedLambda(lamArgs, m, tag)) + case Some(bs) => + val m1 = Match(arg1, bs, tag1) + normalize1(namerec, m1, scope, typeEnv) + } case notApp => if ((notApp eq expr) && (lamArgs === lamArgs0)) None else Some(AnnotatedLambda(lamArgs, notApp, tag)) @@ -186,7 +240,8 @@ object TypedExprNormalization { None case Global(p, n: Bindable, tpe0, tag) => scope.getGlobal(p, n).flatMap { - case (RecursionKind.NonRecursive, te, _) if Impl.isSimple(te, lambdaSimple = false) => + case (RecursionKind.NonRecursive, te, _) + if Impl.isSimple(te, lambdaSimple = false) => // TODO for a reason I don't understand, inlining lambdas here causes a stack overflow // there is probably something somewhat unsound about this substitution that I don't understand Some(te) @@ -213,15 +268,20 @@ object TypedExprNormalization { f1 match { case ws.ResolveToLambda(lamArgs, expr, _) => // (y -> z)(x) = let y = x in z - val lets = lamArgs.zip(args).map { - case ((n, ltpe), arg) => (n, setType(arg, ltpe)) + val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => + (n, setType(arg, ltpe)) } val expr2 = setType(expr, tpe) val l = TypedExpr.letAllNonRec(lets, expr2, tag) normalize1(namerec, l, scope, typeEnv) case Let(arg1, ex, in, rec, tag1) if a1.forall(_.notFree(arg1)) => - // (app (let x y z) w) == (let x y (app z w)) if w does not have x free - normalize1(namerec, Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), scope, typeEnv) + // (app (let x y z) w) == (let x y (app z w)) if w does not have x free + normalize1( + namerec, + Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), + scope, + typeEnv + ) case _ => if ((f1 eq fn) && (a1 == args) && (tpe == tpe0)) None else Some(App(f1, a1, tpe, tag)) @@ -232,7 +292,8 @@ object TypedExprNormalization { val (ni, si) = nameScope(arg, rec, scope) val ex1 = normalize1(ni, ex, si, typeEnv).get ex1 match { - case Let(ex1a, ex1ex, ex1in, RecursionKind.NonRecursive, ex1tag) if !rec.isRecursive && in.notFree(ex1a) => + case Let(ex1a, ex1ex, ex1in, RecursionKind.NonRecursive, ex1tag) + if !rec.isRecursive && in.notFree(ex1a) => // according to a SPJ paper, it is generally better // to float lets out of nesting inside in: // let foo = let bar = x in bar in foo @@ -242,14 +303,23 @@ object TypedExprNormalization { // since you are going to evaluate and keep in scope // the expression // we can lift - val l1 = Let(ex1a, ex1ex, Let(arg, ex1in, in, RecursionKind.NonRecursive, tag), RecursionKind.NonRecursive, ex1tag) + val l1 = Let( + ex1a, + ex1ex, + Let(arg, ex1in, in, RecursionKind.NonRecursive, tag), + RecursionKind.NonRecursive, + ex1tag + ) normalize1(namerec, l1, scope, typeEnv) case _ => val scopeIn = si.updated(arg, (rec, ex1, si)) val in1 = normalize1(namerec, in, scopeIn, typeEnv).get in1 match { - case Match(marg, branches, mtag) if !rec.isRecursive && marg.notFree(arg) && branches.exists { case (p, r) => p.names.contains(arg) || r.notFree(arg) } => + case Match(marg, branches, mtag) + if !rec.isRecursive && marg.notFree(arg) && branches.exists { + case (p, r) => p.names.contains(arg) || r.notFree(arg) + } => // x = y // match z: // case w: ww @@ -273,23 +343,30 @@ object TypedExprNormalization { val shouldInline = (!rec.isRecursive) && { (cnt == 1) || Impl.isSimple(ex1, lambdaSimple = true) } - val inlined = if (shouldInline) substitute(arg, ex1, in1) else None + val inlined = + if (shouldInline) substitute(arg, ex1, in1) else None inlined match { case Some(il) => normalize1(namerec, il, scope, typeEnv) case None => if ((in1 eq in) && (ex1 eq ex)) None - else normalize1(namerec, Let(arg, ex1, in1, rec, tag), scope, typeEnv) + else + normalize1( + namerec, + Let(arg, ex1, in1, rec, tag), + scope, + typeEnv + ) } - } - else { + } else { // let x = y in z if x isn't free in z = z Some(in1) } } } - case Match(_, NonEmptyList((p, e), Nil), _) if !e.freeVarsDup.exists(p.names.toSet) => + case Match(_, NonEmptyList((p, e), Nil), _) + if !e.freeVarsDup.exists(p.names.toSet) => // match x: // foo: fn // @@ -299,13 +376,25 @@ object TypedExprNormalization { // match x: // y: fn // let y = x in fn - normalize1(namerec, Let(y, arg, e, RecursionKind.NonRecursive, tag), scope, typeEnv) + normalize1( + namerec, + Let(y, arg, e, RecursionKind.NonRecursive, tag), + scope, + typeEnv + ) case Match(arg, branches, tag) => - - def ncount(shadows: Iterable[Bindable], e: TypedExpr[A]): (Int, TypedExpr[A]) = + def ncount( + shadows: Iterable[Bindable], + e: TypedExpr[A] + ): (Int, TypedExpr[A]) = // the final result of the branch is what is assigned to the name - normalizeLetOpt(None, e, shadows.foldLeft(scope)(_ - _), typeEnv) match { - case None => (0, e) + normalizeLetOpt( + None, + e, + shadows.foldLeft(scope)(_ - _), + typeEnv + ) match { + case None => (0, e) case Some(e) => (1, e) } // we can remove any bindings that aren't used in branches @@ -328,7 +417,10 @@ object TypedExprNormalization { case Pattern.WildCard => (changed0, branches1) case notWild if notWild.names.isEmpty => - val newb = branches1.init ::: ((Pattern.WildCard, branches1.last._2) :: Nil) + val newb = branches1.init ::: (( + Pattern.WildCard, + branches1.last._2 + ) :: Nil) // this newb list clearly has more than 0 elements (changed0 + 1, NonEmptyList.fromListUnsafe(newb)) case _ => @@ -349,8 +441,7 @@ object TypedExprNormalization { // we can possibly simplify this now: normalize1(namerec, m2, scope, typeEnv) } - } - else { + } else { // there has been some change, so // see if that unlocked any new changes normalize1(namerec, Match(a1, branches1a, tag), scope, typeEnv) @@ -362,21 +453,27 @@ object TypedExprNormalization { private object Impl { - def scopeMatches[A](names: Set[Bindable], scope: Scope[A], scope1: Scope[A]): Boolean = + def scopeMatches[A]( + names: Set[Bindable], + scope: Scope[A], + scope1: Scope[A] + ): Boolean = names.forall { b => (scope.getLocal(b), scope1.getLocal(b)) match { case (None, None) => true case (Some((r1, t1, s1)), Some((r2, t2, s2))) => (r1 == r2) && - (t1.void == t2.void) && - scopeMatches(t1.freeVarsDup.toSet, s1, s2) + (t1.void == t2.void) && + scopeMatches(t1.freeVarsDup.toSet, s1, s2) case _ => false } } case class WithScope[A](scope: Scope[A]) { object ResolveToLambda { - def unapply(te: TypedExpr[A]): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A], A)] = + def unapply( + te: TypedExpr[A] + ): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A], A)] = te match { case AnnotatedLambda(args, expr, ltag) => Some((args, expr, ltag)) case Global(p, n: Bindable, _, _) => @@ -388,10 +485,15 @@ object TypedExprNormalization { // we can't just replace variables if the scopes don't match. // we could also repair the scope by making a let binding // for any names that don't match (which has to be done recursively - if (scopeMatches(expr.freeVarsDup.toSet -- args.iterator.map(_._1), scope, scope1)) { + if ( + scopeMatches( + expr.freeVarsDup.toSet -- args.iterator.map(_._1), + scope, + scope1 + ) + ) { Some((args, expr, ltag)) - } - else None + } else None case _ => None } case _ => None @@ -405,10 +507,15 @@ object TypedExprNormalization { // we can't just replace variables if the scopes don't match. // we could also repair the scope by making a let binding // for any names that don't match (which has to be done recursively - if (scopeMatches(expr.freeVarsDup.toSet -- args.iterator.map(_._1), scope, scope1)) { + if ( + scopeMatches( + expr.freeVarsDup.toSet -- args.iterator.map(_._1), + scope, + scope1 + ) + ) { Some((args, expr, ltag)) - } - else None + } else None case _ => None } case _ => None @@ -422,8 +529,8 @@ object TypedExprNormalization { final def isSimple[A](ex: TypedExpr[A], lambdaSimple: Boolean): Boolean = ex match { case Literal(_, _, _) | Local(_, _, _) | Global(_, _, _, _) => true - case Annotation(t, _) => isSimple(t, lambdaSimple) - case Generic(_, t) => isSimple(t, lambdaSimple) + case Annotation(t, _) => isSimple(t, lambdaSimple) + case Generic(_, t) => isSimple(t, lambdaSimple) case AnnotatedLambda(_, _, _) => // maybe inline lambdas so we can possibly // apply (x -> f)(g) => let x = g in f @@ -433,15 +540,21 @@ object TypedExprNormalization { sealed abstract class EvalResult[A] object EvalResult { - case class Cons[A](pack: PackageName, cons: Constructor, args: List[TypedExpr[A]]) extends EvalResult[A] + case class Cons[A]( + pack: PackageName, + cons: Constructor, + args: List[TypedExpr[A]] + ) extends EvalResult[A] case class Constant[A](lit: Lit) extends EvalResult[A] } object FnArgs { - def unapply[A](te: TypedExpr[A]): Option[(TypedExpr[A], NonEmptyList[TypedExpr[A]])] = + def unapply[A]( + te: TypedExpr[A] + ): Option[(TypedExpr[A], NonEmptyList[TypedExpr[A]])] = te match { case App(fn, args, _, _) => Some((fn, args)) - case _ => None + case _ => None } } @@ -458,23 +571,29 @@ object TypedExprNormalization { case _ => None } case Let(arg, expr, in, RecursionKind.NonRecursive, _) => - evaluate(in, scope.updated(arg, (RecursionKind.NonRecursive, expr, scope))) + evaluate( + in, + scope.updated(arg, (RecursionKind.NonRecursive, expr, scope)) + ) case FnArgs(fn, args) => evaluate(fn, scope).map { - case EvalResult.Cons(p, c, ahead) => EvalResult.Cons(p, c, ahead ::: args.toList) + case EvalResult.Cons(p, c, ahead) => + EvalResult.Cons(p, c, ahead ::: args.toList) case EvalResult.Constant(c) => // this really shouldn't happen, // $COVERAGE-OFF$ - sys.error(s"unreachable: cannot apply a constant: $te => ${fn.repr} => $c") - // $COVERAGE-ON$ + sys.error( + s"unreachable: cannot apply a constant: $te => ${fn.repr} => $c" + ) + // $COVERAGE-ON$ } - case Global(pack, cons: Constructor, _, _) => Some(EvalResult.Cons(pack, cons, Nil)) + case Global(pack, cons: Constructor, _, _) => + Some(EvalResult.Cons(pack, cons, Nil)) case Global(pack, n: Bindable, _, _) => - scope.getGlobal(pack, n).flatMap { - case (_, t, s) => - // Global values never have free values, - // so it is safe to substitute into our current scope - evaluate(t, s) + scope.getGlobal(pack, n).flatMap { case (_, t, s) => + // Global values never have free values, + // so it is safe to substitute into our current scope + evaluate(t, s) } case Generic(_, in) => // if we can evaluate, we are okay @@ -488,26 +607,29 @@ object TypedExprNormalization { type Pat = Pattern[(PackageName, Constructor), Type] type Branch[A] = (Pat, TypedExpr[A]) - def maybeEvalMatch[A](m: Match[_ <: A], scope: Scope[A]): Option[TypedExpr[A]] = + def maybeEvalMatch[A]( + m: Match[_ <: A], + scope: Scope[A] + ): Option[TypedExpr[A]] = evaluate(m.arg, scope).flatMap { case EvalResult.Cons(p, c, args) => - val alen = args.length def isTotal(p: Pat): Boolean = p match { case Pattern.WildCard | Pattern.Var(_) => true - case Pattern.Named(_, p) => isTotal(p) - case Pattern.Annotation(p, _) => isTotal(p) + case Pattern.Named(_, p) => isTotal(p) + case Pattern.Annotation(p, _) => isTotal(p) case Pattern.Union(h, t) => isTotal(h) || t.exists(isTotal) - case _ => false + case _ => false } // The Option signals we can't complete def expandMatches(br: Branch[A]): Option[List[Branch[A]]] = br match { - case (ps@Pattern.PositionalStruct((p0, c0), args0), res) => - if (p0 == p && c0 == c && args0.length == alen) Some((ps, res) :: Nil) + case (ps @ Pattern.PositionalStruct((p0, c0), args0), res) => + if (p0 == p && c0 == c && args0.length == alen) + Some((ps, res) :: Nil) else Some(Nil) case (Pattern.Named(n, p), res) => expandMatches((p, res)).map { bs => @@ -519,9 +641,11 @@ object TypedExprNormalization { // The annotation is only used at inference time, the values have already been typed expandMatches((p, res)) case (Pattern.Union(h, t), r) => - (h :: t.toList).traverse { p => expandMatches((p, r)) }.map(_.flatten) - case br@(p, _) if isTotal(p) => Some(br :: Nil) - case (Pattern.ListPat(_), _) => + (h :: t.toList) + .traverse { p => expandMatches((p, r)) } + .map(_.flatten) + case br @ (p, _) if isTotal(p) => Some(br :: Nil) + case (Pattern.ListPat(_), _) => // TODO some of these patterns we could evaluate None case _ => None @@ -542,13 +666,20 @@ object TypedExprNormalization { m.branches.toList.traverse(expandMatches).map(_.flatten).flatMap { case Nil => // $COVERAGE-OFF$ - sys.error(s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})") - // $COVERAGE-ON$ - case (MaybeNamedStruct(b, pats), r) :: rest if rest.isEmpty || pats.forall(isTotal) => + sys.error( + s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})" + ) + // $COVERAGE-ON$ + case (MaybeNamedStruct(b, pats), r) :: rest + if rest.isEmpty || pats.forall(isTotal) => // If there are no more items, or all inner patterns are total, we are done // exactly one matches, this can be a sequential match - def matchAll(argPat: List[(TypedExpr[A], Pattern[(PackageName, Constructor), Type])]): TypedExpr[A] = + def matchAll( + argPat: List[ + (TypedExpr[A], Pattern[(PackageName, Constructor), Type]) + ] + ): TypedExpr[A] = argPat match { case Nil => r case (a, p) :: tail => @@ -566,7 +697,11 @@ object TypedExprNormalization { } val res = matchAll(args.zip(pats)) - Some(b.foldRight(res)(Let(_, m.arg, _, RecursionKind.NonRecursive, m.tag))) + Some( + b.foldRight(res)( + Let(_, m.arg, _, RecursionKind.NonRecursive, m.tag) + ) + ) case h :: t => // more than one branch might match, wait till runtime val m1 = Match(m.arg, NonEmptyList(h, t), m.tag) @@ -575,12 +710,14 @@ object TypedExprNormalization { } case EvalResult.Constant(li @ Lit.Integer(i)) => - def makeLet(p: Pattern[(PackageName, Constructor), Type]): Option[List[Bindable]] = + def makeLet( + p: Pattern[(PackageName, Constructor), Type] + ): Option[List[Bindable]] = p match { case Pattern.Named(v, p) => makeLet(p).map(v :: _) - case Pattern.WildCard => Some(Nil) - case Pattern.Var(v) => Some(v :: Nil) + case Pattern.WildCard => Some(Nil) + case Pattern.Var(v) => Some(v :: Nil) case Pattern.Annotation(p, _) => makeLet(p) case Pattern.Literal(Lit.Integer(j)) => if (j == i) Some(Nil) @@ -588,7 +725,9 @@ object TypedExprNormalization { case Pattern.Union(h, t) => (h :: t).toList.iterator.map(makeLet).reduce(_.orElse(_)) // $COVERAGE-OFF$ this is ill-typed so should be unreachable - case Pattern.PositionalStruct(_, _) | Pattern.ListPat(_) | Pattern.StrPat(_) | Pattern.Literal(Lit.Str(_)) => None + case Pattern.PositionalStruct(_, _) | Pattern.ListPat(_) | + Pattern.StrPat(_) | Pattern.Literal(Lit.Str(_)) => + None // $COVERAGE-ON$ } @@ -596,10 +735,11 @@ object TypedExprNormalization { def find[X, Y](ls: List[X])(fn: X => Option[Y]): Option[Y] = ls match { case Nil => None - case h :: t => fn(h) match { - case None => find(t)(fn) - case some => some - } + case h :: t => + fn(h) match { + case None => find(t)(fn) + case some => some + } } find[Branch[A], TypedExpr[A]](m.branches.toList) { case (p, r) => diff --git a/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala b/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala index 94ab40499..2d2342cd5 100644 --- a/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala +++ b/core/src/main/scala/org/bykn/bosatsu/UnusedLetCheck.scala @@ -1,7 +1,14 @@ package org.bykn.bosatsu import cats.Applicative -import cats.data.{Chain, NonEmptyList, Validated, ValidatedNec, Writer, NonEmptyChain} +import cats.data.{ + Chain, + NonEmptyList, + Validated, + ValidatedNec, + Writer, + NonEmptyChain +} import cats.implicits._ import Expr._ @@ -10,9 +17,14 @@ import Identifier.Bindable object UnusedLetCheck { private[this] val ap = Applicative[Writer[Chain[(Bindable, Region)], *]] - private[this] val empty: Writer[Chain[(Bindable, Region)], Set[Bindable]] = ap.pure(Set.empty) + private[this] val empty: Writer[Chain[(Bindable, Region)], Set[Bindable]] = + ap.pure(Set.empty) - private[this] def checkArg(arg: Bindable, reg: => Region, w: Writer[Chain[(Bindable, Region)], Set[Bindable]]) = + private[this] def checkArg( + arg: Bindable, + reg: => Region, + w: Writer[Chain[(Bindable, Region)], Set[Bindable]] + ) = w.flatMap { free => if (free(arg)) ap.pure(free - arg) else { @@ -21,14 +33,16 @@ object UnusedLetCheck { } } - private[this] def loop[A: HasRegion](e: Expr[A]): Writer[Chain[(Bindable, Region)], Set[Bindable]] = + private[this] def loop[A: HasRegion]( + e: Expr[A] + ): Writer[Chain[(Bindable, Region)], Set[Bindable]] = e match { case Annotation(expr, _, _) => loop(expr) case Generic(_, in) => loop(in) case Lambda(args, expr, _) => args.toList.foldRight(loop(expr)) { (arg, res) => - checkArg(arg._1, HasRegion.region(e), res) + checkArg(arg._1, HasRegion.region(e), res) } case Let(arg, expr, in, rec, _) => val exprCheck = loop(expr) @@ -38,13 +52,15 @@ object UnusedLetCheck { if (rec.isRecursive) exprCheck.map(_ - arg) else exprCheck // the region of the let isn't directly tracked, but // it would start with the whole region starts and end at expr - val inCheck = checkArg(arg, - { + val inCheck = checkArg( + arg, { val wholeRegion = HasRegion.region(e) val endRegion = HasRegion.region(expr) val bindRegion = wholeRegion.copy(end = endRegion.end) bindRegion - }, loop(in)) + }, + loop(in) + ) (exprRes, inCheck).mapN(_ ++ _) case Local(name, _) => // this is a free variable: @@ -57,40 +73,47 @@ object UnusedLetCheck { // TODO: patterns need their own region val branchRegions = NonEmptyList.fromListUnsafe( - branches.toList.scanLeft((HasRegion.region(arg), Option.empty[Region])) { case ((prev, _), (_, caseExpr)) => - // between the previous expression and the case is the pattern - (HasRegion.region(caseExpr), Some(Region(prev.end, HasRegion.region(caseExpr).start))) - } - .collect { case (_, Some(r)) => r } + branches.toList + .scanLeft((HasRegion.region(arg), Option.empty[Region])) { + case ((prev, _), (_, caseExpr)) => + // between the previous expression and the case is the pattern + ( + HasRegion.region(caseExpr), + Some(Region(prev.end, HasRegion.region(caseExpr).start)) + ) + } + .collect { case (_, Some(r)) => r } ) - val bcheck = branchRegions.zip(branches).traverse { case (region, (pat, expr)) => - loop(expr).flatMap { frees => - val thisPatNames = pat.names - val unused = thisPatNames.filterNot(frees) - val nextFrees = frees -- thisPatNames + val bcheck = branchRegions + .zip(branches) + .traverse { case (region, (pat, expr)) => + loop(expr).flatMap { frees => + val thisPatNames = pat.names + val unused = thisPatNames.filterNot(frees) + val nextFrees = frees -- thisPatNames - ap.pure(nextFrees).tell(Chain.fromSeq(unused.map((_, region)))) + ap.pure(nextFrees).tell(Chain.fromSeq(unused.map((_, region)))) + } } - } - .map(_.combineAll) + .map(_.combineAll) (argCheck, bcheck).mapN(_ ++ _) } - /** - * Check for any unused lets, defs, or pattern bindings - */ - def check[A: HasRegion](e: Expr[A]): ValidatedNec[(Bindable, Region), Unit] = { + /** Check for any unused lets, defs, or pattern bindings + */ + def check[A: HasRegion]( + e: Expr[A] + ): ValidatedNec[(Bindable, Region), Unit] = { val (chain, _) = loop(e).run NonEmptyChain.fromChain(chain) match { - case None => Validated.valid(()) + case None => Validated.valid(()) case Some(nec) => Validated.invalid(nec.distinct) } } - /** - * Return the free Bindable names in this expression - */ + /** Return the free Bindable names in this expression + */ def freeBound[A](e: Expr[A]): Set[Bindable] = loop(e)(HasRegion.instance(_ => Region(0, 0))).run._2 } diff --git a/core/src/main/scala/org/bykn/bosatsu/Value.scala b/core/src/main/scala/org/bykn/bosatsu/Value.scala index d8948bb8f..efa6ffeb2 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Value.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Value.scala @@ -4,50 +4,47 @@ import cats.data.NonEmptyList import java.math.BigInteger import scala.collection.immutable.SortedMap -/** - * If we later determine that this performance matters - * and this wrapping is hurting, we could replace - * Value with a less structured type and put - * all the reflection into unapply calls but keep - * most of the API - */ +/** If we later determine that this performance matters and this wrapping is + * hurting, we could replace Value with a less structured type and put all the + * reflection into unapply calls but keep most of the API + */ sealed abstract class Value { import Value._ def asFn: NonEmptyList[Value] => Value = this match { case FnValue(f) => f - case other => + case other => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to Fn: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } def asSum: SumValue = this match { case s: SumValue => s - case _ => + case _ => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to SumValue: $this") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - def asProduct:ProductValue = + def asProduct: ProductValue = this match { case p: ProductValue => p - case _ => + case _ => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to ProductValue: $this") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } def asExternal: ExternalValue = this match { case ex: ExternalValue => ex - case _ => + case _ => // $COVERAGE-OFF$this should be unreachable sys.error(s"invalid cast to ExternalValue: $this") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } final def applyAll(args: NonEmptyList[Value]): Value = @@ -58,7 +55,7 @@ object Value { sealed abstract class ProductValue extends Value { def toList: List[Value] = this match { - case UnitValue => Nil + case UnitValue => Nil case ConsValue(head, tail) => head :: tail.toList } @@ -70,7 +67,9 @@ object Value { if (ix <= 0) head else loop(tail, ix - 1) case UnitValue => - throw new IllegalArgumentException(s"exhausted index at $ix on ${this}.get($idx)") + throw new IllegalArgumentException( + s"exhausted index at $ix on ${this}.get($idx)" + ) } loop(this, idx) @@ -80,7 +79,7 @@ object Value { object ProductValue { def fromList(ps: List[Value]): ProductValue = ps match { - case Nil => UnitValue + case Nil => UnitValue case h :: tail => ConsValue(h, fromList(tail)) } } @@ -89,10 +88,12 @@ object Value { case class ConsValue(head: Value, tail: ProductValue) extends ProductValue { override val hashCode = (head, tail).hashCode } - final class SumValue(val variant: Int, val value: ProductValue) extends Value { + final class SumValue(val variant: Int, val value: ProductValue) + extends Value { override def equals(that: Any) = that match { - case s: SumValue => (s eq this) || ((variant == s.variant) && (value == s.value)) + case s: SumValue => + (s eq this) || ((variant == s.variant) && (value == s.value)) case _ => false } override def hashCode: Int = @@ -108,7 +109,8 @@ object Value { (0 until constCount).map(new SumValue(_, UnitValue)).toArray def apply(variant: Int, value: ProductValue): SumValue = - if ((value == UnitValue) && ((variant & sizeMask) == 0)) constants(variant) + if ((value == UnitValue) && ((variant & sizeMask) == 0)) + constants(variant) else new SumValue(variant, value) } @@ -123,11 +125,12 @@ object Value { case class SimpleFnValue(toFn: NonEmptyList[Value] => Value) extends Arg - def apply(toFn: NonEmptyList[Value] => Value): FnValue = new FnValue(SimpleFnValue(toFn)) - def unapply(fnValue: FnValue): Some[NonEmptyList[Value] => Value] = Some(fnValue.arg.toFn) + def unapply(fnValue: FnValue): Some[NonEmptyList[Value] => Value] = Some( + fnValue.arg.toFn + ) val identity: FnValue = FnValue(vs => vs.head) } @@ -141,7 +144,7 @@ object Value { def unapply(v: Value): Option[(Value, Value)] = v match { case ConsValue(a, ConsValue(b, UnitValue)) => Some((a, b)) - case _ => None + case _ => None } def apply(a: Value, b: Value): ProductValue = @@ -149,24 +152,22 @@ object Value { } object Tuple { - /** - * Tuples are encoded as: - * (1, 2, 3) => TupleCons(1, TupleCons(2, TupleCons(3, ()))) - * since a Tuple(a, b) is encoded as - * ConsValue(a, ConsValue(b, UnitValue)) - * this gives double wrapping - */ + + /** Tuples are encoded as: (1, 2, 3) => TupleCons(1, TupleCons(2, + * TupleCons(3, ()))) since a Tuple(a, b) is encoded as ConsValue(a, + * ConsValue(b, UnitValue)) this gives double wrapping + */ def unapply(v: Value): Option[List[Value]] = v match { case TupleCons(a, b) => unapply(b).map(a :: _) case UnitValue => Some(Nil) - case _ => None + case _ => None } def fromList(vs: List[Value]): ProductValue = vs match { - case Nil => UnitValue + case Nil => UnitValue case h :: tail => TupleCons(h, fromList(tail)) } @@ -185,7 +186,7 @@ object Value { def fromLit(l: Lit): Value = l match { - case Lit.Str(s) => ExternalValue(s) + case Lit.Str(s) => ExternalValue(s) case Lit.Integer(i) => ExternalValue(i) } @@ -195,7 +196,7 @@ object Value { def unapply(v: Value): Option[BigInteger] = v match { case ExternalValue(v: BigInteger) => Some(v) - case _ => None + case _ => None } } @@ -204,7 +205,7 @@ object Value { def unapply(v: Value): Option[String] = v match { case ExternalValue(str: String) => Some(str) - case _ => None + case _ => None } } @@ -221,10 +222,9 @@ object Value { else if ((s.variant == 1)) { s.value match { case ConsValue(head, UnitValue) => Some(Some(head)) - case _ => None + case _ => None } - } - else None + } else None case _ => None } } @@ -240,11 +240,11 @@ object Value { case s: SumValue => if (s.variant == 1) { s.value match { - case ConsValue(head, ConsValue(rest, UnitValue)) => Some((head, rest)) + case ConsValue(head, ConsValue(rest, UnitValue)) => + Some((head, rest)) case _ => None } - } - else None + } else None case _ => None } } @@ -253,7 +253,7 @@ object Value { @annotation.tailrec def go(vs: List[Value], acc: Value): Value = vs match { - case Nil => acc + case Nil => acc case h :: tail => go(tail, Cons(h, acc)) } go(items.reverse, VNil) @@ -277,11 +277,11 @@ object Value { val v = fn.applyAll( NonEmptyList( - Tuple.fromList(v1 :: null :: Nil), - Tuple.fromList(v2 :: null :: Nil) :: Nil) - ) - .asSum - .variant + Tuple.fromList(v1 :: null :: Nil), + Tuple.fromList(v2 :: null :: Nil) :: Nil + ) + ).asSum + .variant if (v == 0) -1 else if (v == 1) 0 else if (v == 2) 1 @@ -293,15 +293,18 @@ object Value { } } - //enum Tree: Empty, Branch(size: Int, height: Int, key: a, left: Tree[a], right: Tree[a]) - //struct Dict[k, v](ord: Order[(k, v)], tree: Tree[(k, v)]) + // enum Tree: Empty, Branch(size: Int, height: Int, key: a, left: Tree[a], right: Tree[a]) + // struct Dict[k, v](ord: Order[(k, v)], tree: Tree[(k, v)]) def unapply(v: Value): Option[SortedMap[Value, Value]] = v match { case ConsValue(ordFn: FnValue, ConsValue(tree, UnitValue)) => implicit val ord: Ordering[Value] = keyOrderingFromOrdFn(ordFn) - def treeToList(t: Value, acc: SortedMap[Value, Value]): SortedMap[Value, Value] = { + def treeToList( + t: Value, + acc: SortedMap[Value, Value] + ): SortedMap[Value, Value] = { val v = t.asSum if (v.variant == 0) acc // empty else { @@ -313,7 +316,7 @@ object Value { case other => // $COVERAGE-OFF$ sys.error(s"ill-shaped: $other") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } } } @@ -326,21 +329,24 @@ object Value { } val strOrdFn: FnValue = - FnValue { - case NonEmptyList(tup1, tup2 :: Nil) => - (tup1, tup2) match { - case (Tuple(ExternalValue(k1: String) :: _), Tuple(ExternalValue(k2: String) :: _)) => - Comparison.fromInt(k1.compareTo(k2)) - case _ => - // $COVERAGE-OFF$ - sys.error(s"ill-typed in String Dict order: $tup1, $tup2") - // $COVERAGE-ON$ - } + FnValue { case NonEmptyList(tup1, tup2 :: Nil) => + (tup1, tup2) match { + case ( + Tuple(ExternalValue(k1: String) :: _), + Tuple(ExternalValue(k2: String) :: _) + ) => + Comparison.fromInt(k1.compareTo(k2)) + case _ => + // $COVERAGE-OFF$ + sys.error(s"ill-typed in String Dict order: $tup1, $tup2") + // $COVERAGE-ON$ } + } def fromStringKeys(kvs: List[(String, Value)]): Value = { val allItems: Array[(String, Value)] = kvs.toMap.toArray - java.util.Arrays.sort(allItems, Ordering[String].on { (kv: (String, Value)) => kv._1 }) + java.util.Arrays + .sort(allItems, Ordering[String].on { (kv: (String, Value)) => kv._1 }) val empty = (BigInteger.ZERO, BigInteger.ZERO, SumValue(0, UnitValue)) @@ -353,14 +359,21 @@ object Value { val (rh, rz, right) = makeTree(mid + 1, end) val h = lh.max(rh).add(BigInteger.ONE) val z = lz.add(rz).add(BigInteger.ONE) - (h, z, SumValue(1, - ProductValue.fromList( - ExternalValue(z) :: - ExternalValue(h) :: - Tuple.fromList(ExternalValue(k) :: v :: Nil) :: - left :: - right :: - Nil))) + ( + h, + z, + SumValue( + 1, + ProductValue.fromList( + ExternalValue(z) :: + ExternalValue(h) :: + Tuple.fromList(ExternalValue(k) :: v :: Nil) :: + left :: + right :: + Nil + ) + ) + ) } val (_, _, tree) = makeTree(0, allItems.length) diff --git a/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala b/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala index 51735c492..2e3f5d395 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ValueToDoc.scala @@ -14,15 +14,14 @@ import JsonEncodingError.IllTyped case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { - /** - * Convert a typechecked value to a Document representation - * - * Note, we statically build the conversion function if it is possible at - * all, after that, only value errors can occur - * - * this code ASSUMES the type is correct. If not, we may return - * incorrect data if it is not clearly illtyped - */ + /** Convert a typechecked value to a Document representation + * + * Note, we statically build the conversion function if it is possible at + * all, after that, only value errors can occur + * + * this code ASSUMES the type is correct. If not, we may return incorrect + * data if it is not clearly illtyped + */ def toDoc(tpe: Type): Value => Either[IllTyped, Doc] = { type Fn = Value => Either[IllTyped, Doc] @@ -39,33 +38,31 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(fn) => fn case None => val res: Eval[Fn] = Eval.later(tpe match { - case Type.IntType => - { - case ExternalValue(v: BigInteger) => - Right(Doc.str(v)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.StrType => - { - case ExternalValue(v: String) => - Right(Document[Lit].document(Lit.Str(v))) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } + case Type.IntType => { + case ExternalValue(v: BigInteger) => + Right(Doc.str(v)) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } + case Type.StrType => { + case ExternalValue(v: String) => + Right(Document[Lit].document(Lit.Str(v))) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } case Type.UnitType => // encode this as null { case UnitValue => Right(Doc.text("()")) - case other => + case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } + // $COVERAGE-ON$ + } case Type.ListT(t1) => lazy val inner = loop(t1, tpe :: revPath).value @@ -73,12 +70,14 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case VList(vs) => vs.traverse(inner) .map { inners => - Doc.char('[') + (Doc.lineOrEmpty + commaBlock(inners) + Doc.lineOrEmpty).aligned + Doc.char(']') + Doc.char('[') + (Doc.lineOrEmpty + commaBlock( + inners + ) + Doc.lineOrEmpty).aligned + Doc.char(']') } case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.DictT(Type.StrType, vt) => lazy val inner = loop(vt, tpe :: revPath).value @@ -86,25 +85,30 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { { case VDict(d) => - d.toList.traverse { case (k, v) => - k match { - case Str(kstr) => - inner(v).map { vdoc => - (docStr.document(Lit.Str(kstr)) + (Doc.char(':') + Doc.line + vdoc).nested(4)).grouped - } - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) + d.toList + .traverse { case (k, v) => + k match { + case Str(kstr) => + inner(v).map { vdoc => + (docStr.document(Lit.Str(kstr)) + (Doc.char( + ':' + ) + Doc.line + vdoc).nested(4)).grouped + } + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) // $COVERAGE-ON$ + } + } + .map { kvs => + Doc.char('{') + (Doc.lineOrEmpty + commaBlock( + kvs + ) + Doc.lineOrEmpty).aligned + Doc.char('}') } - } - .map { kvs => - Doc.char('{') + (Doc.lineOrEmpty + commaBlock(kvs) + Doc.lineOrEmpty).aligned + Doc.char('}') - } case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.Tuple(ts) => val p1 = tpe :: revPath @@ -117,12 +121,13 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { .toVector .traverse { case (a, fn) => fn(a) } .map { items => - Doc.char('(') + (Doc.lineOrEmpty + commaBlock(items) + Doc.char(',') + Doc.lineOrEmpty).aligned + Doc.char(')') + Doc.char('(') + (Doc.lineOrEmpty + commaBlock(items) + Doc + .char(',') + Doc.lineOrEmpty).aligned + Doc.char(')') } case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.ForAll(_, inner) => @@ -131,7 +136,7 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Type.TyVar(_) => // we don't really know what to do with { _ => Right(Doc.text("")) } - case fn@Type.Fun(_, _) => + case fn @ Type.Fun(_, _) => def arity(fn: Type): Int = fn match { case Type.Fun(_, dest) => @@ -147,7 +152,7 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case _ => // We can have complicated recursion here, we @@ -168,9 +173,11 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(dt) => val cons = dt.constructors val (_, targs) = Type.applicationArgs(tpe) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] - lazy val resInner: Map[Int, (Constructor, List[(String, Fn)])] = + lazy val resInner + : Map[Int, (Constructor, List[(String, Fn)])] = cons.zipWithIndex .traverse { case (cf, idx) => val rec = cf.args.traverse { case (field, t) => @@ -183,7 +190,11 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { .map(_.toMap) .value - def params(variant: Int, params: List[Value], src: Value): Either[IllTyped, Doc] = + def params( + variant: Int, + params: List[Value], + src: Value + ): Either[IllTyped, Doc] = resInner.get(variant) match { case None => Left(IllTyped(revPath.reverse, tpe, src)) @@ -193,42 +204,43 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { .zip(fields) .traverse { case (v, (nm, fn)) => fn(v).map { vdoc => - (Doc.text(nm) + Doc.char(':') + Doc.lineOrSpace + vdoc).nested(4) + (Doc.text(nm) + Doc.char( + ':' + ) + Doc.lineOrSpace + vdoc).nested(4) } } .map { paramsDoc => val nm = Doc.text(name.asString) if (paramsDoc.isEmpty) nm else { - nm + Doc.space + - (Doc.char('{') + (Doc.line + commaBlock(paramsDoc)).nested(4) + Doc.line + Doc.char('}')).grouped + nm + Doc.space + + (Doc.char('{') + (Doc.line + commaBlock( + paramsDoc + )).nested(4) + Doc.line + Doc + .char('}')).grouped } } - } - else Left(IllTyped(revPath.reverse, tpe, src)) + } else Left(IllTyped(revPath.reverse, tpe, src)) } - dt.dataFamily match { case DataFamily.NewType => // the outer wrapping is so we add it back { v => params(0, v :: Nil, v) } - case DataFamily.Struct => - { - case prod: ProductValue => - params(0, prod.toList, prod) + case DataFamily.Struct => { + case prod: ProductValue => + params(0, prod.toList, prod) - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - } - case DataFamily.Enum => - { - case s: SumValue => - params(s.variant, s.value.toList, s) - case a => - Left(IllTyped(revPath.reverse, tpe, a)) - } + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + } + case DataFamily.Enum => { + case s: SumValue => + params(s.variant, s.value.toList, s) + case a => + Left(IllTyped(revPath.reverse, tpe, a)) + } case DataFamily.Nat => // this is nat-like // TODO, maybe give a warning @@ -239,8 +251,8 @@ case class ValueToDoc(getDefinedType: Type.Const => Option[DefinedType[Any]]) { Left(IllTyped(revPath.reverse, tpe, other)) } } - } - }) + } + }) // put the result in the cache before we compute it // so we can recurse successCache.put(tpe, res) diff --git a/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala b/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala index 8df50bca7..591666b10 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ValueToJson.scala @@ -15,19 +15,17 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { def canEncodeToNull(t: Type): Boolean = t match { - case Type.UnitType => true + case Type.UnitType => true case Type.OptionT(inner) => // if the inside of an Option cannot be null, we can use null // to represent None !canEncodeToNull(inner) case Type.ForAll(_, inner) => canEncodeToNull(inner) - case _ => false + case _ => false } - - /** - * Is a given type supported for Json conversion - */ + /** Is a given type supported for Json conversion + */ def supported(t: Type): Either[UnsupportedType, Unit] = { // if we are currently working on a Type // we assume it is supported, and it isn't @@ -39,13 +37,13 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { Left(UnsupportedType(NonEmptyList(t, working).reverse)) t match { - case _ if working.contains(t) => good + case _ if working.contains(t) => good case Type.IntType | Type.StrType | Type.BoolType | Type.UnitType => good - case Type.OptionT(inner) => loop(inner, t :: working) - case Type.ListT(inner) => loop(inner, t :: working ) + case Type.OptionT(inner) => loop(inner, t :: working) + case Type.ListT(inner) => loop(inner, t :: working) case Type.DictT(Type.StrType, inner) => loop(inner, t :: working) - case Type.ForAll(_, _) => bad - case Type.TyVar(_) | Type.TyMeta(_) => bad + case Type.ForAll(_, _) => bad + case Type.TyVar(_) | Type.TyMeta(_) => bad case Type.Tuple(ts) => val w1 = t :: working ts.traverse_(loop(_, w1)) @@ -59,7 +57,8 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(dt) => val cons = dt.constructors val (_, targs) = Type.applicationArgs(consOrApply) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] cons.traverse_ { cf => cf.args.traverse_ { case (_, t) => @@ -80,19 +79,20 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { // $COVERAGE-OFF$ case Left(u) => sys.error(s"should have only called on a supported type: $u") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - /** - * Convert a typechecked value to Json - * - * Note, we statically build the conversion function if it is possible at - * all, after that, only value errors can occur - * - * this code ASSUMES the type is correct. If not, we may return - * incorrect data. - */ - def toJson(tpe: Type): Either[UnsupportedType, Value => Either[IllTyped, Json]] = { + /** Convert a typechecked value to Json + * + * Note, we statically build the conversion function if it is possible at + * all, after that, only value errors can occur + * + * this code ASSUMES the type is correct. If not, we may return incorrect + * data. + */ + def toJson( + tpe: Type + ): Either[UnsupportedType, Value => Either[IllTyped, Json]] = { type Fn = Value => Either[IllTyped, Json] // when we complete a custom type, we put it in here @@ -105,60 +105,57 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(fn) => fn case None => val res: Eval[Fn] = Eval.later(tpe match { - case Type.IntType => - { - case ExternalValue(v: BigInteger) => - Right(Json.JNumberStr(v.toString)) - // $COVERAGE-OFF$this should be unreachable - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.StrType => - { - case ExternalValue(v: String) => - Right(Json.JString(v)) + case Type.IntType => { + case ExternalValue(v: BigInteger) => + Right(Json.JNumberStr(v.toString)) + // $COVERAGE-OFF$this should be unreachable + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } + case Type.StrType => { + case ExternalValue(v: String) => + Right(Json.JString(v)) + // $COVERAGE-OFF$this should be unreachable + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } + case Type.BoolType => { + case True => Right(Json.JBool(true)) + case False => Right(Json.JBool(false)) + case other => // $COVERAGE-OFF$this should be unreachable - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.BoolType => - { - case True => Right(Json.JBool(true)) - case False => Right(Json.JBool(false)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } + Left(IllTyped(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ + } case Type.UnitType => // encode this as null { case UnitValue => Right(Json.JNull) - case other => + case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case opt@Type.OptionT(t1) => + // $COVERAGE-ON$ + } + case opt @ Type.OptionT(t1) => lazy val inner = loop(t1, tpe :: revPath).value if (canEncodeToNull(opt)) { - // not a nested option + // not a nested option { - case VOption(None) => Right(Json.JNull) + case VOption(None) => Right(Json.JNull) case VOption(Some(a)) => inner(a) case other => Left(IllTyped(revPath.reverse, tpe, other)) } - } - else { + } else { { case VOption(None) => Right(Json.JArray(Vector.empty)) - case VOption(Some(a)) => inner(a).map { j => Json.JArray(Vector(j)) } + case VOption(Some(a)) => + inner(a).map { j => Json.JArray(Vector(j)) } case other => Left(IllTyped(revPath.reverse, tpe, other)) } @@ -174,27 +171,28 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.DictT(Type.StrType, vt) => lazy val inner = loop(vt, tpe :: revPath).value { case VDict(d) => - d.toList.traverse { case (k, v) => - k match { - case Str(kstr) => inner(v).map((kstr, _)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTyped(revPath.reverse, tpe, other)) + d.toList + .traverse { case (k, v) => + k match { + case Str(kstr) => inner(v).map((kstr, _)) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTyped(revPath.reverse, tpe, other)) // $COVERAGE-ON$ + } } - } - .map(Json.JObject(_)) + .map(Json.JObject(_)) case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.Tuple(ts) => val p1 = tpe :: revPath @@ -210,7 +208,7 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => // $COVERAGE-OFF$this should be unreachable Left(IllTyped(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } case Type.ForAll(_, inner) => @@ -227,24 +225,26 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { getDefinedType(const) match { case Some(dt) => Right(dt) case None => - Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) + Left( + UnsupportedType(NonEmptyList(tpe, revPath).reverse) + ) } case None => Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) }) dt.dataFamily match { - case DataFamily.Nat => - { - case ExternalValue(b: BigInteger) => - Right(Json.JNumberStr(b.toString)) - case other => - Left(IllTyped(revPath.reverse, tpe, other)) - } + case DataFamily.Nat => { + case ExternalValue(b: BigInteger) => + Right(Json.JNumberStr(b.toString)) + case other => + Left(IllTyped(revPath.reverse, tpe, other)) + } case notNat => val cons = dt.constructors val (_, targs) = Type.applicationArgs(tpe) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] val resInner: Eval[Map[Int, List[(String, Fn)]]] = cons.zipWithIndex @@ -258,7 +258,6 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { } .map(_.toMap) - notNat match { case DataFamily.NewType => lazy val inner = resInner.value.head._2.head._2 @@ -273,13 +272,13 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { val plist = prod.toList if (plist.size == size) { - plist.zip(productsInner) + plist + .zip(productsInner) .traverse { case (p, (key, f)) => f(p).map((key, _)) } .map { ps => Json.JObject(ps) } - } - else { + } else { Left(IllTyped(revPath.reverse, tpe, prod)) } @@ -297,41 +296,42 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case Some(fn) => val vlist = s.value.toList if (vlist.size == fn.size) { - vlist.zip(fn) + vlist + .zip(fn) .traverse { case (p, (key, f)) => f(p).map((key, _)) } .map { ps => Json.JObject(ps) } - } - else Left(IllTyped(revPath.reverse, tpe, s)) + } else Left(IllTyped(revPath.reverse, tpe, s)) case None => Left(IllTyped(revPath.reverse, tpe, s)) } case a => Left(IllTyped(revPath.reverse, tpe, a)) - } - } + } + } } - }) + }) // put the result in the cache before we compute it // so we can recurse successCache.put(tpe, res) res - } + } supported(tpe).map(_ => loop(tpe, Nil).value) } - /** - * Convert a Json to a Value - * - * Note, we statically build the conversion function if it is possible at - * all, after that, only value errors can occur - * - * this code ASSUMES the type is correct. If not, we may return - * incorrect data. - */ - def toValue(tpe: Type): Either[UnsupportedType, Json => Either[IllTypedJson, Value]] = { + /** Convert a Json to a Value + * + * Note, we statically build the conversion function if it is possible at + * all, after that, only value errors can occur + * + * this code ASSUMES the type is correct. If not, we may return incorrect + * data. + */ + def toValue( + tpe: Type + ): Either[UnsupportedType, Json => Either[IllTypedJson, Value]] = { type Fn = Json => Either[IllTypedJson, Value] // when we complete a custom type, we put it in here @@ -341,192 +341,199 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { successCache.get(tpe) match { case Some(res) => res case None => - val res: Eval[Json => Either[IllTypedJson, Value]] = Eval.later(tpe match { - case Type.IntType => - { + val res: Eval[Json => Either[IllTypedJson, Value]] = + Eval.later(tpe match { + case Type.IntType => { case Json.JBigInteger(b) => Right(ExternalValue(b)) case other => Left(IllTypedJson(revPath.reverse, tpe, other)) } - case Type.StrType => - { + case Type.StrType => { case Json.JString(v) => Right(ExternalValue(v)) case other => // $COVERAGE-OFF$this should be unreachable Left(IllTypedJson(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - case Type.BoolType => - { + case Type.BoolType => { case Json.JBool(value) => Right(if (value) True else False) case other => // $COVERAGE-OFF$this should be unreachable Left(IllTypedJson(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ - } - case Type.UnitType => - // encode this as null - { - case Json.JNull => Right(UnitValue) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) - // $COVERAGE-ON$ + // $COVERAGE-ON$ } - case opt@Type.OptionT(t1) => - if (canEncodeToNull(opt)) { - // not a nested option - lazy val inner = loop(t1, tpe :: revPath).value - + case Type.UnitType => + // encode this as null { - case Json.JNull => Right(VOption.none) - case notNull => inner(notNull).map(VOption.some(_)) + case Json.JNull => Right(UnitValue) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) + // $COVERAGE-ON$ } - } - else { - // we can't encode Option[Option[T]] as null or not, so we encode - // as list of 0 or 1 items + case opt @ Type.OptionT(t1) => + if (canEncodeToNull(opt)) { + // not a nested option + lazy val inner = loop(t1, tpe :: revPath).value + + { + case Json.JNull => Right(VOption.none) + case notNull => inner(notNull).map(VOption.some(_)) + } + } else { + // we can't encode Option[Option[T]] as null or not, so we encode + // as list of 0 or 1 items - lazy val inner = loop(t1, tpe :: revPath).value + lazy val inner = loop(t1, tpe :: revPath).value - { - case Json.JArray(items) if items.lengthCompare(1) <= 0 => - items.headOption match { - case None => Right(VOption.none) - case Some(a) => inner(a).map(VOption.some(_)) - } - case other => + { + case Json.JArray(items) if items.lengthCompare(1) <= 0 => + items.headOption match { + case None => Right(VOption.none) + case Some(a) => inner(a).map(VOption.some(_)) + } + case other => Left(IllTypedJson(revPath.reverse, tpe, other)) + } } - } - case Type.ListT(t) => - lazy val inner = loop(t, tpe :: revPath).value + case Type.ListT(t) => + lazy val inner = loop(t, tpe :: revPath).value - { - case Json.JArray(vs) => - vs.toVector - .traverse(inner) - .map { vs => VList(vs.toList) } - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) + { + case Json.JArray(vs) => + vs.toVector + .traverse(inner) + .map { vs => VList(vs.toList) } + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) // $COVERAGE-ON$ - } - case Type.DictT(Type.StrType, vt) => - lazy val inner = loop(vt, tpe :: revPath).value + } + case Type.DictT(Type.StrType, vt) => + lazy val inner = loop(vt, tpe :: revPath).value - { - case Json.JObject(items) => - items.traverse { case (k, v) => - inner(v).map((k, _)) - } - .map { kvs => - VDict.fromStringKeys(kvs) - } - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) + { + case Json.JObject(items) => + items + .traverse { case (k, v) => + inner(v).map((k, _)) + } + .map { kvs => + VDict.fromStringKeys(kvs) + } + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) // $COVERAGE-ON$ - } - case Type.Tuple(ts) => - val p1 = tpe :: revPath - lazy val inners = ts.traverse(loop(_, p1)).value + } + case Type.Tuple(ts) => + val p1 = tpe :: revPath + lazy val inners = ts.traverse(loop(_, p1)).value - { - case ary@Json.JArray(as) => - if (as.size == inners.size) { - as.zip(inners) - .toVector - .traverse { case (a, fn) => fn(a) } - .map { vs => Tuple.fromList(vs.toList) } - } - else Left(IllTypedJson(revPath.reverse, tpe, ary)) - case other => - // $COVERAGE-OFF$this should be unreachable - Left(IllTypedJson(revPath.reverse, tpe, other)) + { + case ary @ Json.JArray(as) => + if (as.size == inners.size) { + as.zip(inners) + .toVector + .traverse { case (a, fn) => fn(a) } + .map { vs => Tuple.fromList(vs.toList) } + } else Left(IllTypedJson(revPath.reverse, tpe, ary)) + case other => + // $COVERAGE-OFF$this should be unreachable + Left(IllTypedJson(revPath.reverse, tpe, other)) // $COVERAGE-ON$ - } - - case Type.ForAll(_, inner) => - // we assume the generic positions don't matter and to continue - loop(inner, tpe :: revPath).value - case _ => - val fullPath = tpe :: revPath - - val dt = - get(Type.rootConst(tpe) match { - case Some(Type.TyConst(const)) => - getDefinedType(const) match { - case Some(dt) => Right(dt) - case None => - Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) - } - case None => - Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) - }) + } - val resInner: Eval[ - List[(Int, List[(String, Json => Either[IllTypedJson, Value])])] - ] = { + case Type.ForAll(_, inner) => + // we assume the generic positions don't matter and to continue + loop(inner, tpe :: revPath).value + case _ => + val fullPath = tpe :: revPath + + val dt = + get(Type.rootConst(tpe) match { + case Some(Type.TyConst(const)) => + getDefinedType(const) match { + case Some(dt) => Right(dt) + case None => + Left( + UnsupportedType(NonEmptyList(tpe, revPath).reverse) + ) + } + case None => + Left(UnsupportedType(NonEmptyList(tpe, revPath).reverse)) + }) + + val resInner: Eval[ + List[ + (Int, List[(String, Json => Either[IllTypedJson, Value])]) + ] + ] = { val cons = dt.constructors val (_, targs) = Type.applicationArgs(tpe) - val replaceMap = dt.typeParams.zip(targs).toMap[Type.Var, Type] + val replaceMap = + dt.typeParams.zip(targs).toMap[Type.Var, Type] cons.zipWithIndex .traverse { case (cf, idx) => - cf.args.traverse { case (pn, t) => - val subsT = Type.substituteVar(t, replaceMap) - loop(subsT, fullPath) - .map((pn.asString, _)) - } - .map { pair => (idx, pair) } + cf.args + .traverse { case (pn, t) => + val subsT = Type.substituteVar(t, replaceMap) + loop(subsT, fullPath) + .map((pn.asString, _)) + } + .map { pair => (idx, pair) } } - } + } - dt.dataFamily match { - case DataFamily.NewType => - // there is one single arg constructor - lazy val inner = resInner.value.head._2.head._2 + dt.dataFamily match { + case DataFamily.NewType => + // there is one single arg constructor + lazy val inner = resInner.value.head._2.head._2 - { j => inner(j) } - case DataFamily.Struct | DataFamily.Enum => + { j => inner(j) } + case DataFamily.Struct | DataFamily.Enum => // This is lazy because we don't want to run // the Evals until we have the first value lazy val mapping: List[(Int, Map[String, (Int, Fn)])] = // if we are in here, all constituent parts can be solved - resInner.value.map { case (idx, lst) => - (idx, - lst - .iterator - .zipWithIndex - .map { case ((nm, fn), idx) => (nm, (idx, fn)) } - .toMap) - } + resInner.value.map { case (idx, lst) => + ( + idx, + lst.iterator.zipWithIndex.map { + case ((nm, fn), idx) => (nm, (idx, fn)) + }.toMap + ) + } { - case obj@Json.JObject(_) => + case obj @ Json.JObject(_) => val keySet = obj.toMap.keySet - def run(cand: List[(Int, Map[String, (Int, Fn)])]): Either[IllTypedJson, Value] = + def run( + cand: List[(Int, Map[String, (Int, Fn)])] + ): Either[IllTypedJson, Value] = cand match { case Nil => Left(IllTypedJson(revPath.reverse, tpe, obj)) - case (variant, decode) :: _ if keySet == decode.keySet => + case (variant, decode) :: _ + if keySet == decode.keySet => val itemArray = new Array[Value](keySet.size) - obj.items.foldM(itemArray) { case (ary, (k, v)) => - val (idx, fn) = decode(k) - fn(v).map { value => - ary(idx) = value - ary + obj.items + .foldM(itemArray) { case (ary, (k, v)) => + val (idx, fn) = decode(k) + fn(v).map { value => + ary(idx) = value + ary + } + } + .map { ary => + val prod = ProductValue.fromList(ary.toList) + if (dt.isStruct) prod + else SumValue(variant, prod) } - } - .map { ary => - val prod = ProductValue.fromList(ary.toList) - if (dt.isStruct) prod - else SumValue(variant, prod) - } case _ :: tail => run(tail) } @@ -534,15 +541,15 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { case other => Left(IllTypedJson(revPath.reverse, tpe, other)) } - case DataFamily.Nat => - // this is a nat like type which we encode into integers - { - case Json.JBigInteger(bi) => - Right(ExternalValue(bi)) - case other => - Left(IllTypedJson(revPath.reverse, tpe, other)) - } - } + case DataFamily.Nat => + // this is a nat like type which we encode into integers + { + case Json.JBigInteger(bi) => + Right(ExternalValue(bi)) + case other => + Left(IllTypedJson(revPath.reverse, tpe, other)) + } + } }) successCache.put(tpe, res) @@ -552,11 +559,13 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { supported(tpe).map(_ => loop(tpe, Nil).value) } - /** - * Given a type return the function to convert it a function - * if it is not a function, we consider it a function of 0-arity - */ - def valueFnToJsonFn(t: Type): Either[UnsupportedType, (Int, Value => Either[DataError, Json.JArray => Either[DataError, Json]])] = + /** Given a type return the function to convert it a function if it is not a + * function, we consider it a function of 0-arity + */ + def valueFnToJsonFn(t: Type): Either[ + UnsupportedType, + (Int, Value => Either[DataError, Json.JArray => Either[DataError, Json]]) + ] = t match { case Type.Fun((args, res)) => (args.traverse(toValue(_)), toJson(res)).mapN { (argsFn, resFn) => @@ -565,38 +574,44 @@ case class ValueToJson(getDefinedType: Type.Const => Option[DefinedType[Any]]) { val arity = argsFn.size val argsFnVector = argsFn.toList.toVector - (arity, { - case Value.FnValue(fn) => - - val jsonFn = { (inputs: Json.JArray) => - if (inputs.toVector.size != arity) Left(IllTypedJson(Nil, t, inputs)) - else { - // we know arity >= 1 because it is a function, so the fromListUnsafe will succeed - inputs.toVector - .zip(argsFnVector) - .traverse { case (a, fn) => fn(a) } - .map { vect => fn(NonEmptyList.fromListUnsafe(vect.toList)) } - .flatMap(resFn) + ( + arity, + { + case Value.FnValue(fn) => + val jsonFn = { (inputs: Json.JArray) => + if (inputs.toVector.size != arity) + Left(IllTypedJson(Nil, t, inputs)) + else { + // we know arity >= 1 because it is a function, so the fromListUnsafe will succeed + inputs.toVector + .zip(argsFnVector) + .traverse { case (a, fn) => fn(a) } + .map { vect => + fn(NonEmptyList.fromListUnsafe(vect.toList)) + } + .flatMap(resFn) + } } - } - Right(jsonFn) - case notFn => Left(IllTyped(Nil, t, notFn)) - }) + Right(jsonFn) + case notFn => Left(IllTyped(Nil, t, notFn)) + } + ) } case _ => // this isn't a function at all toJson(t).map { (fn: (Value) => Either[DataError, Json]) => - - (0, fn.andThen { either => - either.map { result => - - { (args: Json.JArray) => - if (args.toVector.isEmpty) Right(result) - else Left(IllTypedJson(Nil, t, args)) + ( + 0, + fn.andThen { either => + either.map { result => + { (args: Json.JArray) => + if (args.toVector.isEmpty) Right(result) + else Left(IllTypedJson(Nil, t, args)) + } } } - }) + ) } } @@ -606,8 +621,11 @@ sealed abstract class JsonEncodingError object JsonEncodingError { sealed abstract class DataError extends JsonEncodingError - final case class UnsupportedType(path: NonEmptyList[Type]) extends JsonEncodingError + final case class UnsupportedType(path: NonEmptyList[Type]) + extends JsonEncodingError - final case class IllTyped(path: List[Type], tpe: Type, value: Value) extends DataError - final case class IllTypedJson(path: List[Type], tpe: Type, value: Json) extends DataError + final case class IllTyped(path: List[Type], tpe: Type, value: Value) + extends DataError + final case class IllTypedJson(path: List[Type], tpe: Type, value: Json) + extends DataError } diff --git a/core/src/main/scala/org/bykn/bosatsu/Variance.scala b/core/src/main/scala/org/bykn/bosatsu/Variance.scala index 73fcd2af6..c26acc454 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Variance.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Variance.scala @@ -8,35 +8,35 @@ sealed abstract class Variance { def unary_- : Variance = this match { case Contravariant => Covariant - case Covariant => Contravariant - case topOrBottom => topOrBottom + case Covariant => Contravariant + case topOrBottom => topOrBottom } // if you have f[x] the variance of the result is the arg of f times variance of x def *(that: Variance): Variance = (this, that) match { - case (Phantom, _) => Phantom - case (_, Phantom) => Phantom - case (Invariant, _) => Invariant - case (_, Invariant) => Invariant - case (Covariant, r) => r + case (Phantom, _) => Phantom + case (_, Phantom) => Phantom + case (Invariant, _) => Invariant + case (_, Invariant) => Invariant + case (Covariant, r) => r case (Contravariant, Contravariant) => Covariant - case (Contravariant, Covariant) => Contravariant + case (Contravariant, Covariant) => Contravariant } - /** - * Variance forms a lattice with Phantom at the bottom and Invariant at the top. - */ + /** Variance forms a lattice with Phantom at the bottom and Invariant at the + * top. + */ def +(that: Variance): Variance = (this, that) match { - case (Phantom, r) => r - case (r, Phantom) => r - case (Invariant, _) => Invariant - case (_, Invariant) => Invariant - case (Covariant, Covariant) => Covariant + case (Phantom, r) => r + case (r, Phantom) => r + case (Invariant, _) => Invariant + case (_, Invariant) => Invariant + case (Covariant, Covariant) => Covariant case (Contravariant, Contravariant) => Contravariant - case (Covariant, Contravariant) => Invariant - case (Contravariant, Covariant) => Invariant + case (Covariant, Contravariant) => Invariant + case (Contravariant, Covariant) => Invariant } } object Variance { @@ -51,7 +51,8 @@ object Variance { def contra: Variance = Contravariant def in: Variance = Invariant - val all: List[Variance] = Phantom :: Covariant :: Contravariant :: Invariant :: Nil + val all: List[Variance] = + Phantom :: Covariant :: Contravariant :: Invariant :: Nil implicit val varianceBoundedSemilattice: BoundedSemilattice[Variance] = new BoundedSemilattice[Variance] { @@ -67,15 +68,15 @@ object Variance { case Phantom => if (right == Phantom) 0 else -1 case Covariant => right match { - case Phantom => 1 - case Covariant => 0 + case Phantom => 1 + case Covariant => 0 case Contravariant | Invariant => -1 } - case Contravariant => + case Contravariant => right match { case Phantom | Covariant => 1 - case Contravariant => 0 - case Invariant => -1 + case Contravariant => 0 + case Invariant => -1 } case Invariant => if (right == Invariant) 0 else 1 } diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala index 9bac78718..37ff99534 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/Code.scala @@ -21,9 +21,9 @@ object Code { sealed abstract class Expression extends ValueLike with Code { def identOrParens: Expression = this match { - case i: Code.Ident => i - case p@Code.Parens(_) => p - case other => Code.Parens(other) + case i: Code.Ident => i + case p @ Code.Parens(_) => p + case other => Code.Parens(other) } def apply(args: Expression*): Apply = @@ -60,7 +60,7 @@ object Code { def statements: NonEmptyList[Statement] = this match { case Block(ss) => ss - case notBlock => NonEmptyList.one(notBlock) + case notBlock => NonEmptyList.one(notBlock) } def +:(stmt: Statement): Statement = @@ -84,8 +84,8 @@ object Code { case Pass => vl case _ => vl match { - case wv@WithValue(_, _) => this +: wv - case _ => WithValue(this, vl) + case wv @ WithValue(_, _) => this +: wv + case _ => WithValue(this, vl) } } } @@ -96,11 +96,12 @@ object Code { private def maybePar(c: Expression): Doc = c match { case Lambda(_, _) | Ternary(_, _, _) => par(toDoc(c)) - case _ => toDoc(c) + case _ => toDoc(c) } private def iflike(name: String, cond: Doc, body: Doc): Doc = - Doc.text(name) + Doc.space + cond + Doc.char(':') + (Doc.hardLine + body).nested(4) + Doc.text(name) + Doc.space + cond + Doc.char(':') + (Doc.hardLine + body) + .nested(4) private val trueDoc = Doc.text("True") private val falseDoc = Doc.text("False") @@ -118,42 +119,59 @@ object Code { def exprToDoc(expr: Expression): Doc = expr match { case PyInt(bi) => Doc.text(bi.toString) - case PyString(s) => Doc.char('"') + Doc.text(StringUtil.escape('"', s)) + Doc.char('"') + case PyString(s) => + Doc.char('"') + Doc.text(StringUtil.escape('"', s)) + Doc.char('"') case PyBool(b) => if (b) trueDoc else falseDoc - case Ident(i) => Doc.text(i) - case o@Op(_, _, _) => o.toDoc - case Parens(inner@Parens(_)) => exprToDoc(inner) - case Parens(p) => par(exprToDoc(p)) + case Ident(i) => Doc.text(i) + case o @ Op(_, _, _) => o.toDoc + case Parens(inner @ Parens(_)) => exprToDoc(inner) + case Parens(p) => par(exprToDoc(p)) case SelectItem(x, i) => maybePar(x) + Doc.char('[') + Doc.str(i) + Doc.char(']') case SelectRange(x, os, oe) => - val middle = os.fold(Doc.empty)(exprToDoc) + Doc.char(':') + oe.fold(Doc.empty)(exprToDoc) + val middle = os.fold(Doc.empty)(exprToDoc) + Doc.char(':') + oe.fold( + Doc.empty + )(exprToDoc) maybePar(x) + (Doc.char('[') + middle + Doc.char(']')).nested(4) case Ternary(ift, cond, iff) => // python parses the else condition as the rest of experssion, so // no need to put parens around it - maybePar(ift) + spaceIfSpace + maybePar(cond) + spaceElseSpace + exprToDoc(iff) + maybePar(ift) + spaceIfSpace + maybePar( + cond + ) + spaceElseSpace + exprToDoc(iff) case MakeTuple(items) => items match { - case Nil => unitDoc + case Nil => unitDoc case h :: Nil => par(exprToDoc(h) + Doc.comma).nested(4) - case twoOrMore => par(Doc.intercalate(Doc.comma + Doc.line, twoOrMore.map(exprToDoc)).grouped).nested(4) + case twoOrMore => + par( + Doc + .intercalate(Doc.comma + Doc.line, twoOrMore.map(exprToDoc)) + .grouped + ).nested(4) } case MakeList(items) => val inner = items.map(exprToDoc) - (Doc.char('[') + Doc.intercalate(Doc.comma + Doc.line, inner).grouped + Doc.char(']')).nested(4) + (Doc.char('[') + Doc + .intercalate(Doc.comma + Doc.line, inner) + .grouped + Doc.char(']')).nested(4) case Lambda(args, res) => - lamDoc + Doc.intercalate(Doc.comma + Doc.space, args.map(exprToDoc)) + colonSpace + exprToDoc(res) + lamDoc + Doc.intercalate( + Doc.comma + Doc.space, + args.map(exprToDoc) + ) + colonSpace + exprToDoc(res) case Apply(fn, args) => - maybePar(fn) + par(Doc.intercalate(Doc.comma + Doc.line, args.map(exprToDoc)).grouped).nested(4) + maybePar(fn) + par( + Doc.intercalate(Doc.comma + Doc.line, args.map(exprToDoc)).grouped + ).nested(4) case DotSelect(left, right) => val ld = left match { case PyInt(_) => par(exprToDoc(left)) - case _ => exprToDoc(left) + case _ => exprToDoc(left) } ld + Doc.char('.') + exprToDoc(right) } @@ -161,14 +179,16 @@ object Code { def toDoc(c: Code): Doc = c match { case expr: Expression => exprToDoc(expr) - case Call(ap) => toDoc(ap) + case Call(ap) => toDoc(ap) case ClassDef(name, ex, body) => val exDoc = if (ex.isEmpty) Doc.empty else par(Doc.intercalate(Doc.comma + Doc.space, ex.map(toDoc))) - Doc.text("class") + Doc.space + Doc.text(name.name) + exDoc + Doc.char(':') + (Doc.hardLine + + Doc.text("class") + Doc.space + Doc.text(name.name) + exDoc + Doc.char( + ':' + ) + (Doc.hardLine + toDoc(body)).nested(4) case IfStatement(conds, Some(Pass)) => @@ -178,7 +198,9 @@ object Code { val condsDoc = conds.map { case (x, b) => (toDoc(x), toDoc(b)) } val i1 = iflike("if", condsDoc.head._1, condsDoc.head._2) val i2 = condsDoc.tail.map { case (x, b) => iflike("elif", x, b) } - val el = optElse.fold(Doc.empty) { els => Doc.hardLine + elseColon + (Doc.hardLine + toDoc(els)).nested(4) } + val el = optElse.fold(Doc.empty) { els => + Doc.hardLine + elseColon + (Doc.hardLine + toDoc(els)).nested(4) + } Doc.intercalate(Doc.hardLine, i1 :: i2) + el @@ -187,18 +209,23 @@ object Code { case Def(nm, args, body) => defDoc + Doc.space + Doc.text(nm.name) + - par(Doc.intercalate(Doc.comma + Doc.lineOrSpace, args.map(toDoc))).nested(4) + Doc.char(':') + (Doc.hardLine + toDoc(body)).nested(4) + par(Doc.intercalate(Doc.comma + Doc.lineOrSpace, args.map(toDoc))) + .nested(4) + Doc.char(':') + (Doc.hardLine + toDoc(body)).nested(4) case Return(expr) => retSpaceDoc + toDoc(expr) case Assign(nm, expr) => toDoc(nm) + spaceEqSpace + toDoc(expr) - case Pass => Doc.text("pass") + case Pass => Doc.text("pass") case While(cond, body) => - whileDoc + Doc.space + toDoc(cond) + Doc.char(':') + (Doc.hardLine + toDoc(body)).nested(4) + whileDoc + Doc.space + toDoc(cond) + Doc.char( + ':' + ) + (Doc.hardLine + toDoc(body)).nested(4) case Import(name, aliasOpt) => // import name as alias val imp = Doc.text("import") + Doc.space + Doc.text(name) - aliasOpt.fold(imp) { a => imp + Doc.space + Doc.text("as") + Doc.space + toDoc(a) } + aliasOpt.fold(imp) { a => + imp + Doc.space + Doc.text("as") + Doc.space + toDoc(a) + } } ///////////////////////// @@ -218,12 +245,16 @@ object Code { def simplify = this } // Binary operator used for +, -, and, == etc... - case class Op(left: Expression, op: Operator, right: Expression) extends Expression { + case class Op(left: Expression, op: Operator, right: Expression) + extends Expression { // operators like + can associate // def toDoc: Doc = { // invariant: all items in right associate - def loop(left: Expression, rights: NonEmptyList[(Operator, Expression)]): Doc = + def loop( + left: Expression, + rights: NonEmptyList[(Operator, Expression)] + ): Doc = // a op1 b op2 c if op1 and op2 associate no need for a parens // left match { // case Op(_, @@ -234,7 +265,7 @@ object Code { else loop(Parens(left), rights) case leftNotOp => rights.head match { - case (ol, or@Op(_, o2, _)) if !ol.associates(o2) => + case (ol, or @ Op(_, o2, _)) if !ol.associates(o2) => // we we can't break rights.head because the ops // don't associate. We wrap or in Parens val rights1 = (ol, Parens(or)) @@ -245,17 +276,22 @@ object Code { case (ol, rightNotOp) => rights.tail match { case Nil => - maybePar(leftNotOp) + Doc.space + Doc.text(ol.name) + Doc.space + maybePar(rightNotOp) + maybePar(leftNotOp) + Doc.space + Doc.text( + ol.name + ) + Doc.space + maybePar(rightNotOp) case (o2, r2) :: rest => // everything in rights associate - val leftDoc = maybePar(leftNotOp) + Doc.space + Doc.text(ol.name) + Doc.space + val leftDoc = maybePar(leftNotOp) + Doc.space + Doc.text( + ol.name + ) + Doc.space if (ol.associates(o2)) { leftDoc + loop(rightNotOp, NonEmptyList((o2, r2), rest)) - } - else { + } else { // we need to put a parens ending after rightNotOp // leftNotOp ol (rightNotOp o2 r2 :: rest) - leftDoc + par(loop(rightNotOp, NonEmptyList((o2, r2), rest))) + leftDoc + par( + loop(rightNotOp, NonEmptyList((o2, r2), rest)) + ) } } @@ -270,11 +306,11 @@ object Code { this match { case Op(PyInt(a), io: IntOp, PyInt(b)) => PyInt(io(a, b)) - case Op(i@PyInt(a), Const.Times, right) => + case Op(i @ PyInt(a), Const.Times, right) => if (a == BigInteger.ZERO) i else if (a == BigInteger.ONE) right.simplify else right.simplify.evalTimes(i) - case Op(left, Const.Times, i@PyInt(b)) => + case Op(left, Const.Times, i @ PyInt(b)) => if (b == BigInteger.ZERO) i else if (b == BigInteger.ONE) left.simplify else { @@ -282,14 +318,14 @@ object Code { if (l1 == left) this else (l1.evalTimes(i)) } - case Op(i@PyInt(a), Const.Plus, right) => + case Op(i @ PyInt(a), Const.Plus, right) => if (a == BigInteger.ZERO) right.simplify else { val r1 = right.simplify // put the constant on the right r1.evalPlus(i) } - case Op(left, Const.Plus, i@PyInt(b)) => + case Op(left, Const.Plus, i @ PyInt(b)) => if (b == BigInteger.ZERO) left.simplify else { val l1 = left.simplify @@ -301,16 +337,15 @@ object Code { // right associate ll.evalPlus(rl.evalPlus(i)) case Const.Minus => - //(ll - rl) + i == ll - (rl - i) + // (ll - rl) + i == ll - (rl - i) ll.evalMinus(rl.evalMinus(i)) case _ => this } case _ => this } - } - else (l1.evalPlus(i)) + } else (l1.evalPlus(i)) } - case Op(i@PyInt(_), Const.Minus, right) => + case Op(i @ PyInt(_), Const.Minus, right) => val r1 = right.simplify if (r1 == right) { r1 match { @@ -320,9 +355,9 @@ object Code { // right associate rl.evalPlus(rr.evalPlus(i)) case Const.Minus => - //i - (rl - rr) + // i - (rl - rr) rr match { - case ri@PyInt(_) => + case ri @ PyInt(_) => Op(i.evalPlus(ri), Const.Minus, rl) case _ => this } @@ -330,9 +365,8 @@ object Code { } case _ => this } - } - else (i.evalMinus(r1)) - case Op(left, Const.Minus, i@PyInt(b)) => + } else (i.evalMinus(r1)) + case Op(left, Const.Minus, i @ PyInt(b)) => if (b == BigInteger.ZERO) left.simplify else { val l1 = left.simplify @@ -344,16 +378,15 @@ object Code { // (ll + rl) - i == ll + (rl - i) ll.evalPlus(rl.evalMinus(i)) case Const.Minus => - //(ll - rl) - i == ll - (rl + i) + // (ll - rl) - i == ll - (rl + i) ll.evalMinus(rl.evalPlus(i)) case _ => this } case _ => this } - } - else (l1.evalMinus(i)) + } else (l1.evalMinus(i)) } - case Op(a, Const.Eq, b) if a == b => Const.True + case Op(a, Const.Eq, b) if a == b => Const.True case Op(a, Const.Gt | Const.Lt | Const.Neq, b) if a == b => Const.False case Op(PyInt(a), Const.Gt, PyInt(b)) => fromBoolean(a.compareTo(b) > 0) @@ -365,13 +398,13 @@ object Code { fromBoolean(a == b) case Op(a, Const.And, b) => a.simplify match { - case Const.True => b.simplify + case Const.True => b.simplify case Const.False => Const.False case a1 => b.simplify match { - case Const.True => a1 + case Const.True => a1 case Const.False => Const.False - case b1 => Op(a1, Const.And, b1) + case b1 => Op(a1, Const.And, b1) } } case _ => @@ -379,8 +412,7 @@ object Code { val r1 = right.simplify if ((l1 != left) || (r1 != right)) { Op(l1, op, r1).simplify - } - else { + } else { (left, op) match { case (Op(ll, Const.Plus, lr), Const.Plus) => // right associate @@ -403,7 +435,8 @@ object Code { case class Parens(expr: Expression) extends Expression { def simplify: Expression = expr.simplify match { - case x@(PyBool(_) | Ident(_) | PyInt(_) | PyString(_) | Parens(_)) => x + case x @ (PyBool(_) | Ident(_) | PyInt(_) | PyString(_) | Parens(_)) => + x case exprS => Parens(exprS) } } @@ -419,10 +452,15 @@ object Code { } } // foo[a:b] - case class SelectRange(arg: Expression, start: Option[Expression], end: Option[Expression]) extends Expression { + case class SelectRange( + arg: Expression, + start: Option[Expression], + end: Option[Expression] + ) extends Expression { def simplify = SelectRange(arg, start.map(_.simplify), end.map(_.simplify)) } - case class Ternary(ifTrue: Expression, cond: Expression, ifFalse: Expression) extends Expression { + case class Ternary(ifTrue: Expression, cond: Expression, ifFalse: Expression) + extends Expression { def simplify: Expression = cond.simplify match { case PyBool(b) => @@ -454,15 +492,18 @@ object Code { ///////////////////////// // this prepares an expression with a number of statements - case class WithValue(statement: Statement, value: ValueLike) extends ValueLike { + case class WithValue(statement: Statement, value: ValueLike) + extends ValueLike { def +:(stmt: Statement): WithValue = WithValue(stmt +: statement, value) def :+(stmt: Statement): WithValue = WithValue(statement :+ stmt, value) } - case class IfElse(conds: NonEmptyList[(Expression, ValueLike)], elseCond: ValueLike) extends ValueLike - + case class IfElse( + conds: NonEmptyList[(Expression, ValueLike)], + elseCond: ValueLike + ) extends ValueLike ///////////////////////// // Here are all the Statements @@ -470,29 +511,42 @@ object Code { case class Call(sideEffect: Apply) extends Statement // extends are really certain DotSelects, but we can't constrain that much - case class ClassDef(name: Ident, extendList: List[Expression], body: Statement) extends Statement + case class ClassDef( + name: Ident, + extendList: List[Expression], + body: Statement + ) extends Statement case class Block(stmts: NonEmptyList[Statement]) extends Statement - case class IfStatement(conds: NonEmptyList[(Expression, Statement)], elseCond: Option[Statement]) extends Statement - case class Def(name: Ident, args: List[Ident], body: Statement) extends Statement + case class IfStatement( + conds: NonEmptyList[(Expression, Statement)], + elseCond: Option[Statement] + ) extends Statement + case class Def(name: Ident, args: List[Ident], body: Statement) + extends Statement case class Return(expr: Expression) extends Statement case class Assign(target: Expression, value: Expression) extends Statement case object Pass extends Statement case class While(cond: Expression, body: Statement) extends Statement case class Import(modname: String, alias: Option[Ident]) extends Statement - def ifStatement(conds: NonEmptyList[(Expression, Statement)], elseCond: Option[Statement]): Statement = { + def ifStatement( + conds: NonEmptyList[(Expression, Statement)], + elseCond: Option[Statement] + ): Statement = { val simpConds = conds.map { case (e, s) => (e.simplify, s) } val allBranches: NonEmptyList[(Expression, Statement)] = elseCond match { case Some(s) => simpConds :+ ((Code.Const.True, s)) - case None => simpConds + case None => simpConds } // we know the returned expression is never a constant expression - def untilTrue(lst: List[(Expression, Statement)]): (List[(Expression, Statement)], Statement) = + def untilTrue( + lst: List[(Expression, Statement)] + ): (List[(Expression, Statement)], Statement) = lst match { - case Nil => (Nil, Pass) + case Nil => (Nil, Pass) case (Code.Const.True, last) :: _ => (Nil, last) case head :: tail => val (rest, e) = untilTrue(tail) @@ -535,15 +589,15 @@ object Code { def flatten(s: Statement): List[Statement] = s match { - case Pass => Nil + case Pass => Nil case Block(stmts) => stmts.toList.flatMap(flatten) - case single => single :: Nil + case single => single :: Nil } def block(stmt: Statement, rest: Statement*): Statement = { val all = (stmt :: rest.toList).flatMap(flatten) all match { - case Nil => Pass + case Nil => Pass case one :: Nil => one case head :: tail => Block(NonEmptyList(head, tail)) @@ -560,14 +614,14 @@ object Code { unapply(stmts.last).map { case (s0, i, e) => val s1 = NonEmptyList.fromList(stmts.init) match { - case None => s0 + case None => s0 case Some(inits) => Block(inits) :+ s0 } (s1, i, e) } case Assign(i @ Ident(_), expr) => Some((Pass, i, expr)) - case _ => None + case _ => None } } @@ -583,10 +637,10 @@ object Code { conds.map { case (c, v) => (c, toReturn(v)) }, - Some(toReturn(elseCond))) + Some(toReturn(elseCond)) + ) } - // boolean expressions can contain side effects // this runs the side effects but discards // and resulting value @@ -608,7 +662,7 @@ object Code { def litToExpr(lit: Lit): Expression = lit match { - case Lit.Str(s) => PyString(s) + case Lit.Str(s) => PyString(s) case Lit.Integer(bi) => PyInt(bi) } @@ -620,7 +674,6 @@ object Code { else if (i == 1L) Const.One else PyInt(BigInteger.valueOf(i)) - def fromBoolean(b: Boolean): Expression = if (b) Code.Const.True else Code.Const.False @@ -628,9 +681,9 @@ object Code { def associates(that: Operator): Boolean = { // true if (a this b) that c == a this (b that c) this match { - case Const.Plus => (that == Const.Plus) || (that == Const.Minus) + case Const.Plus => (that == Const.Plus) || (that == Const.Minus) case Const.Minus => false - case Const.And => that == Const.And + case Const.And => that == Const.And case Const.Times => // (a * b) * c == a * (b * c) // (a * b) + c != a * (b + c) @@ -643,11 +696,11 @@ object Code { sealed abstract class IntOp(nm: String) extends Operator(nm) { def apply(a: BigInteger, b: BigInteger): BigInteger = this match { - case Const.Plus => a.add(b) + case Const.Plus => a.add(b) case Const.Minus => a.subtract(b) case Const.Times => a.multiply(b) - case Const.Div => PredefImpl.divBigInteger(a, b) - case Const.Mod => PredefImpl.modBigInteger(a, b) + case Const.Div => PredefImpl.divBigInteger(a, b) + case Const.Mod => PredefImpl.modBigInteger(a, b) } } @@ -676,13 +729,36 @@ object Code { "[_A-Za-z][_0-9A-Za-z]*".r.pattern val pyKeywordList: Set[String] = Set( - "and", "del", "from", "not", "while", - "as", "elif", "global", "or", "with", - "assert", "else", "if", "pass", "yield", - "break", "except", "import", "print", - "class", "exec", "in", "raise", - "continue", "finally", "is", "return", - "def", "for", "lambda", "try" + "and", + "del", + "from", + "not", + "while", + "as", + "elif", + "global", + "or", + "with", + "assert", + "else", + "if", + "pass", + "yield", + "break", + "except", + "import", + "print", + "class", + "exec", + "in", + "raise", + "continue", + "finally", + "is", + "return", + "def", + "for", + "lambda", + "try" ) } - diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index c02aedc93..c515bd816 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -3,7 +3,14 @@ package org.bykn.bosatsu.codegen.python import cats.Monad import cats.data.{NonEmptyList, State} import cats.parse.{Parser => P} -import org.bykn.bosatsu.{PackageName, Identifier, Matchless, Par, Parser, RecursionKind} +import org.bykn.bosatsu.{ + PackageName, + Identifier, + Matchless, + Par, + Parser, + RecursionKind +} import org.bykn.bosatsu.rankn.Type import org.typelevel.paiges.Doc @@ -38,18 +45,24 @@ object PythonGen { private object Impl { case class EnvState( - imports: Map[Module, Code.Ident], - bindings: Map[Bindable, (Int, List[Code.Ident])], - tops: Set[Bindable], - nextTmp: Long) { - - private def bindInc(b: Bindable, inc: Int)(fn: Int => Code.Ident): (EnvState, Code.Ident) = { + imports: Map[Module, Code.Ident], + bindings: Map[Bindable, (Int, List[Code.Ident])], + tops: Set[Bindable], + nextTmp: Long + ) { + + private def bindInc(b: Bindable, inc: Int)( + fn: Int => Code.Ident + ): (EnvState, Code.Ident) = { val (c, s) = bindings.getOrElse(b, (0, Nil)) val pname = fn(c) - (copy( - bindings = bindings.updated(b, (c + inc, pname :: s)) - ), pname) + ( + copy( + bindings = bindings.updated(b, (c + inc, pname :: s)) + ), + pname + ) } def bind(b: Bindable): (EnvState, Code.Ident) = @@ -69,11 +82,13 @@ object PythonGen { // see if we are shadowing, or top level bindings.get(b) match { case Some((_, h :: _)) => h - case _ if tops(b) => escape(b) - case other => + case _ if tops(b) => escape(b) + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"unexpected deref: $b with bindings: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"unexpected deref: $b with bindings: $other" + ) + // $COVERAGE-ON$ } def unbind(b: Bindable): EnvState = @@ -82,11 +97,14 @@ object PythonGen { copy(bindings = bindings.updated(b, (cnt, tail))) case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"invalid scope: $other for $b with $bindings") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"invalid scope: $other for $b with $bindings" + ) + // $COVERAGE-ON$ } - def getNextTmp: (EnvState, Long) = (copy(nextTmp = nextTmp + 1L), nextTmp) + def getNextTmp: (EnvState, Long) = + (copy(nextTmp = nextTmp + 1L), nextTmp) def topLevel(b: Bindable): (EnvState, Code.Ident) = (copy(tops = tops + b), escape(b)) @@ -96,13 +114,14 @@ object PythonGen { case Some(alias) => (this, alias) case None => val impNumber = imports.size - val alias = Code.Ident(escapeRaw("___i", mod.last.name + impNumber.toString)) + val alias = Code.Ident( + escapeRaw("___i", mod.last.name + impNumber.toString) + ) (copy(imports = imports.updated(mod, alias)), alias) } def importStatements: List[Code.Import] = - imports - .iterator + imports.iterator .map { case (path, alias) => val modName = path.map(_.name).toList.mkString(".") Code.Import(modName, Some(alias)) @@ -167,7 +186,8 @@ object PythonGen { Monad[Env].pure(Code.Ident(s"___a$long")) def newAssignableVar: Env[Code.Ident] = - Impl.env(_.getNextTmp) + Impl + .env(_.getNextTmp) .map { long => Code.Ident(s"___t$long") } @@ -186,8 +206,14 @@ object PythonGen { def topLevelName(n: Bindable): Env[Code.Ident] = Impl.env(_.topLevel(n)) - def onLastsM(cs: List[ValueLike])(fn: List[Expression] => Env[ValueLike]): Env[ValueLike] = { - def loop(cs: List[ValueLike], setup: List[Statement], args: List[Expression]): Env[ValueLike] = + def onLastsM( + cs: List[ValueLike] + )(fn: List[Expression] => Env[ValueLike]): Env[ValueLike] = { + def loop( + cs: List[ValueLike], + setup: List[Statement], + args: List[Expression] + ): Env[ValueLike] = cs match { case Nil => val res = fn(args.reverse) @@ -195,11 +221,11 @@ object PythonGen { case None => res case Some(nel) => val stmts = nel.reverse - val stmt = Code.block(stmts.head, stmts.tail :_*) + val stmt = Code.block(stmts.head, stmts.tail: _*) res.map(stmt.withValue(_)) } - case (e: Expression) :: t => loop(t, setup, e :: args) - case (ifelse@IfElse(_, _)) :: tail => + case (e: Expression) :: t => loop(t, setup, e :: args) + case (ifelse @ IfElse(_, _)) :: tail => // we allocate a result and assign // the result on each value Env.newAssignableVar.flatMap { v => @@ -212,31 +238,44 @@ object PythonGen { loop(cs, Nil, Nil) } - def onLasts(cs: List[ValueLike])(fn: List[Expression] => ValueLike): Env[ValueLike] = + def onLasts(cs: List[ValueLike])( + fn: List[Expression] => ValueLike + ): Env[ValueLike] = onLastsM(cs)(fn.andThen(Monad[Env].pure(_))) - def onLastM(c: ValueLike)(fn: Expression => Env[ValueLike]): Env[ValueLike] = + def onLastM( + c: ValueLike + )(fn: Expression => Env[ValueLike]): Env[ValueLike] = onLastsM(c :: Nil) { case x :: Nil => fn(x) - case other => + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected list to have size 1: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected list to have size 1: $other" + ) + // $COVERAGE-ON$ } def onLast(c: ValueLike)(fn: Expression => ValueLike): Env[ValueLike] = onLastM(c)(fn.andThen(Monad[Env].pure(_))) - def onLast2(c1: ValueLike, c2: ValueLike)(fn: (Expression, Expression) => ValueLike): Env[ValueLike] = + def onLast2(c1: ValueLike, c2: ValueLike)( + fn: (Expression, Expression) => ValueLike + ): Env[ValueLike] = onLasts(c1 :: c2 :: Nil) { case x1 :: x2 :: Nil => fn(x1, x2) - case other => + case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected list to have size 2: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected list to have size 2: $other" + ) + // $COVERAGE-ON$ } - def ifElse(conds: NonEmptyList[(ValueLike, ValueLike)], elseV: ValueLike): Env[ValueLike] = { + def ifElse( + conds: NonEmptyList[(ValueLike, ValueLike)], + elseV: ValueLike + ): Env[ValueLike] = { // for all the non-expression conditions, we need to defer evaluating them // until they are really needed conds match { @@ -273,9 +312,16 @@ object PythonGen { } } - def ifElseS(cond: ValueLike, thenS: Statement, elseS: Statement): Env[Statement] = + def ifElseS( + cond: ValueLike, + thenS: Statement, + elseS: Statement + ): Env[Statement] = cond match { - case x: Expression => Monad[Env].pure(ifStatement(NonEmptyList.one((x, thenS)), Some(elseS))) + case x: Expression => + Monad[Env].pure( + ifStatement(NonEmptyList.one((x, thenS)), Some(elseS)) + ) case WithValue(stmt, vl) => ifElseS(vl, thenS, elseS).map(stmt +: _) case v => @@ -292,7 +338,8 @@ object PythonGen { def andCode(c1: ValueLike, c2: ValueLike): Env[ValueLike] = (c1, c2) match { - case (t: Expression, c2) if t.simplify == Code.Const.True => Monad[Env].pure(c2) + case (t: Expression, c2) if t.simplify == Code.Const.True => + Monad[Env].pure(c2) case (_, x2: Expression) => onLast(c1)(_.evalAnd(x2)) case _ => @@ -305,7 +352,8 @@ object PythonGen { res <- Env.newAssignableVar ifstmt <- ifElseS(x1, res := c2, Code.Pass) } yield { - Code.block( + Code + .block( res := Code.Const.False, ifstmt ) @@ -314,54 +362,67 @@ object PythonGen { } } - def makeDef(defName: Code.Ident, arg: NonEmptyList[Code.Ident], v: ValueLike): Code.Def = + def makeDef( + defName: Code.Ident, + arg: NonEmptyList[Code.Ident], + v: ValueLike + ): Code.Def = Code.Def(defName, arg.toList, toReturn(v)) - def replaceTailCallWithAssign(name: Ident, argSize: Int, body: ValueLike)(onArgs: List[Expression] => Statement): Env[ValueLike] = { + def replaceTailCallWithAssign(name: Ident, argSize: Int, body: ValueLike)( + onArgs: List[Expression] => Statement + ): Env[ValueLike] = { val initBody = body def loop(body: ValueLike): Env[ValueLike] = body match { - case a@Apply(fn0, args0) => + case a @ Apply(fn0, args0) => if (fn0 == name) { if (args0.length == argSize) { val all = onArgs(args0) // set all the values and return the empty tuple Monad[Env].pure(all.withValue(Code.Const.Unit)) - } - else { + } else { // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected a tailcall for $name in $initBody, but found: $a") + throw new IllegalStateException( + s"expected a tailcall for $name in $initBody, but found: $a" + ) // $COVERAGE-ON$ } - } - else { + } else { Monad[Env].pure(a) } case Parens(p) => loop(p).flatMap(onLast(_)(Parens(_))) case IfElse(ifCases, elseCase) => // only the result types are in tail position, we don't need to recurse on conds - val ifs = ifCases.traverse { case (cond, res) => loop(res).map((cond, _)) } + val ifs = ifCases.traverse { case (cond, res) => + loop(res).map((cond, _)) + } (ifs, loop(elseCase)) .mapN(ifElse(_, _)) .flatten case Ternary(ifTrue, cond, ifFalse) => // both results are in the tail position - (loop(ifTrue), loop(ifFalse)) - .mapN { (t, f) => - ifElse(NonEmptyList.one((cond, t)), f) - } - .flatten + (loop(ifTrue), loop(ifFalse)).mapN { (t, f) => + ifElse(NonEmptyList.one((cond, t)), f) + }.flatten case WithValue(stmt, v) => loop(v).map(stmt.withValue(_)) // the rest cannot have a call in the tail position - case DotSelect(_, _) | Op(_, _, _) | Lambda(_, _) | MakeTuple(_) | MakeList(_) | SelectItem(_, _) | SelectRange(_, _, _) | Ident(_) | PyBool(_) | PyString(_) | PyInt(_) => Monad[Env].pure(body) + case DotSelect(_, _) | Op(_, _, _) | Lambda(_, _) | MakeTuple(_) | + MakeList(_) | SelectItem(_, _) | SelectRange(_, _, _) | Ident(_) | + PyBool(_) | PyString(_) | PyInt(_) => + Monad[Env].pure(body) } loop(initBody) } // these are always recursive so we can use def to define them - def buildLoop(selfName: Ident, fnMutArgs: NonEmptyList[(Ident, Ident)], body: ValueLike): Env[Statement] = { + def buildLoop( + selfName: Ident, + fnMutArgs: NonEmptyList[(Ident, Ident)], + body: ValueLike + ): Env[Statement] = { /* * bodyUpdate = body except App(foo, args) is replaced with * reseting the inputs, and setting cont to True and having @@ -383,17 +444,14 @@ object PythonGen { // we could mutate a variable a later expression depends on // some times we generate code that does x = x, remove those cases val (left, right) = - mutArgs.toList.zip(args) - .filter { case (x, y) => x != y } - .unzip + mutArgs.toList.zip(args).filter { case (x, y) => x != y }.unzip Code.block( cont := Const.True, if (left.isEmpty) Pass else if (left.lengthCompare(1) == 0) { left.head := right.head - } - else { + } else { (MakeTuple(left) := MakeTuple(right)) } ) @@ -404,7 +462,9 @@ object PythonGen { ac = assignMut(cont)(fnArgs.toList) res <- Env.newAssignableVar ar = Assign(res, Code.Const.Unit) - body1 <- replaceTailCallWithAssign(selfName, mutArgs.length, body)(assignMut(cont)) + body1 <- replaceTailCallWithAssign(selfName, mutArgs.length, body)( + assignMut(cont) + ) setRes = res := body1 loop = While(cont, Assign(cont, Const.False) +: setRes) newBody = (ac +: ar +: loop).withValue(res) @@ -413,10 +473,10 @@ object PythonGen { } - private[this] val base62Items = (('0' to '9') ++ ('A' to 'Z') ++ ('a' to 'z')).toSet + private[this] val base62Items = + (('0' to '9') ++ ('A' to 'Z') ++ ('a' to 'z')).toSet private def toBase62(c: Char): String = - if (base62Items(c)) c.toString else if (c == '_') "__" else { @@ -425,8 +485,7 @@ object PythonGen { // $COVERAGE-OFF$ sys.error(s"invalid in: $i0") // $COVERAGE-ON$ - } - else if (i0 < 10) (i0 + '0'.toInt).toChar + } else if (i0 < 10) (i0 + '0'.toInt).toChar else if (i0 < 36) (i0 - 10 + 'A'.toInt).toChar else if (i0 < 62) (i0 - 36 + 'a'.toInt).toChar else { @@ -450,11 +509,15 @@ object PythonGen { private def escapeRaw(prefix: String, str: String): String = str.map(toBase62).mkString(prefix, "", "") - private def unBase62(str: String, offset: Int, bldr: java.lang.StringBuilder): Int = { + private def unBase62( + str: String, + offset: Int, + bldr: java.lang.StringBuilder + ): Int = { var idx = offset var num = 0 - while(idx < str.length) { + while (idx < str.length) { val c = str.charAt(idx) idx += 1 if (c == '_') { @@ -463,14 +526,12 @@ object PythonGen { val numC = num.toChar bldr.append(numC) return (idx - offset) - } - else { + } else { // "__" decodes to "_" bldr.append('_') return (idx - offset) } - } - else { + } else { val base = if (c <= '9') '0'.toInt else if (c <= 'Z') ('A'.toInt - 10) @@ -494,7 +555,10 @@ object PythonGen { // ___b: shadowable (internal) names def escape(n: Bindable): Code.Ident = { val str = n.asString - if (!str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code.pyKeywordList(str)) Code.Ident(str) + if ( + !str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code + .pyKeywordList(str) + ) Code.Ident(str) else { // we need to escape Code.Ident(escapeRaw("___n", str)) @@ -502,7 +566,10 @@ object PythonGen { } def escapeModule(str: String): Code.Ident = { - if (!str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code.pyKeywordList(str)) Code.Ident(str) + if ( + !str.startsWith("___") && Code.python2Name.matcher(str).matches && !Code + .pyKeywordList(str) + ) Code.Ident(str) else { // we need to escape Code.Ident(escapeRaw("___m", str)) @@ -523,15 +590,13 @@ object PythonGen { else { idx += res } - } - else { + } else { bldr.append(c) } } bldr.toString() - } - else { + } else { str } @@ -543,16 +608,18 @@ object PythonGen { } } - /** - * Remap is used to handle remapping external values - */ - private def apply(packName: PackageName, name: Bindable, me: Expr)(remap: (PackageName, Bindable) => Env[Option[ValueLike]]): Env[Statement] = { + /** Remap is used to handle remapping external values + */ + private def apply(packName: PackageName, name: Bindable, me: Expr)( + remap: (PackageName, Bindable) => Env[Option[ValueLike]] + ): Env[Statement] = { val ops = new Impl.Ops(packName, remap) // if we have a top level let rec with the same name, handle it more cleanly val nmVeEnv = me match { - case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2)) if ((n1 === name) && (n2 === name)) => + case Let(Right((n1, RecursionKind.NonRecursive)), inner, Local(n2)) + if ((n1 === name) && (n2 === name)) => // we can just bind now at the top level for { nm <- Env.topLevelName(name) @@ -579,10 +646,11 @@ object PythonGen { // def test_all(self): // # iterate through making assertions // - (Env.importLiteral(NonEmptyList.one(Code.Ident("unittest"))), + ( + Env.importLiteral(NonEmptyList.one(Code.Ident("unittest"))), Env.newAssignableVar, Env.topLevelName(name) - ) + ) .mapN { (importedName, tmpVar, testName) => import Impl._ @@ -595,48 +663,54 @@ object PythonGen { // Assertion(bool, msg) val testAssertion: Code.Statement = - Code.Call(Code.Apply(selfName.dot(Code.Ident("assertTrue")), - argName.get(1) :: argName.get(2) :: Nil)) + Code.Call( + Code.Apply( + selfName.dot(Code.Ident("assertTrue")), + argName.get(1) :: argName.get(2) :: Nil + ) + ) // TestSuite(suiteName, tests) val testSuite: Code.Statement = Code.block( tmpVar := argName.get(2), // get the test list - Code.While(isNonEmpty(tmpVar), + Code.While( + isNonEmpty(tmpVar), Code.block( Code.Call(Code.Apply(loopName, headList(tmpVar) :: Nil)), tmpVar := tailList(tmpVar) ) ) - ) + ) val loopBody: Code.Statement = Code.IfStatement( NonEmptyList.one((isAssertion, testAssertion)), - Some(testSuite)) + Some(testSuite) + ) val recTest = - Code.Def( - loopName, - argName :: Nil, - loopBody) + Code.Def(loopName, argName :: Nil, loopBody) val body = - Code.block( - recTest, - Code.Call(Code.Apply(loopName, testName :: Nil))) + Code.block(recTest, Code.Call(Code.Apply(loopName, testName :: Nil))) val defBody = - Code.Def(Code.Ident("test_all"), - selfName :: Nil, - body) + Code.Def(Code.Ident("test_all"), selfName :: Nil, body) - Code.ClassDef(Code.Ident("BosatsuTests"), List(importedName.dot(Code.Ident("TestCase"))), - defBody) + Code.ClassDef( + Code.Ident("BosatsuTests"), + List(importedName.dot(Code.Ident("TestCase"))), + defBody + ) } } - private def addMainEval(name: Bindable, mod: Module, ci: Code.Ident): Env[Statement] = + private def addMainEval( + name: Bindable, + mod: Module, + ci: Code.Ident + ): Env[Statement] = /* * this does: * if __name__ == "__main__": @@ -669,7 +743,10 @@ object PythonGen { Parser.dictLikeParser(Identifier.bindableParser, modParser) val outer: P[List[(PackageName, List[(Bindable, (Module, Code.Ident))])]] = - Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser(PackageName.parser, inner) <* Parser.maybeSpacesAndLines + Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser( + PackageName.parser, + inner + ) <* Parser.maybeSpacesAndLines outer.map { items => items.flatMap { case (p, bs) => @@ -681,27 +758,31 @@ object PythonGen { // parses a map of of evaluators // { fullyqualifiedType: foo.bar.baz, } val evaluatorParser: P[List[(Type, (Module, Code.Ident))]] = - Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser(Type.fullyResolvedParser, modParser) <* Parser.maybeSpacesAndLines + Parser.maybeSpacesAndLines.with1 *> Parser.dictLikeParser( + Type.fullyResolvedParser, + modParser + ) <* Parser.maybeSpacesAndLines // compile a set of packages given a set of external remappings def renderAll( - pm: Map[PackageName, List[(Bindable, Expr)]], - externals: Map[(PackageName, Bindable), (Module, Code.Ident)], - tests: Map[PackageName, Bindable], - evaluators: Map[PackageName, (Bindable, Module, Code.Ident)])(implicit ec: Par.EC): Map[PackageName, (Module, Doc)] = { - - val externalRemap: (PackageName, Bindable) => Env[Option[ValueLike]] = - { (p, b) => + pm: Map[PackageName, List[(Bindable, Expr)]], + externals: Map[(PackageName, Bindable), (Module, Code.Ident)], + tests: Map[PackageName, Bindable], + evaluators: Map[PackageName, (Bindable, Module, Code.Ident)] + )(implicit ec: Par.EC): Map[PackageName, (Module, Doc)] = { + + val externalRemap: (PackageName, Bindable) => Env[Option[ValueLike]] = { + (p, b) => externals.get((p, b)) match { case None => Monad[Env].pure(None) case Some((m, i)) => - Env.importLiteral(m) + Env + .importLiteral(m) .map { alias => Some(Code.DotSelect(alias, i)) } } - } + } - val all = pm - .toList + val all = pm.toList .traverse { case (p, lets) => Par.start { val stmts0: Env[List[Statement]] = @@ -711,7 +792,9 @@ object PythonGen { } val evalStmt: Env[Option[Statement]] = - evaluators.get(p).traverse { case (b, m, c) => addMainEval(b, m, c) } + evaluators.get(p).traverse { case (b, m, c) => + addMainEval(b, m, c) + } val testStmt: Env[Option[Statement]] = tests.get(p).traverse(addUnitTest) @@ -760,72 +843,114 @@ object PythonGen { lst.get(2) object PredefExternal { - private val cmpFn: List[ValueLike] => Env[ValueLike] = { - input => - Env.onLast2(input.head, input.tail.head) { (arg0, arg1) => - // 0 if arg0 < arg1 else ( - // 1 if arg0 == arg1 else 2 - // ) + private val cmpFn: List[ValueLike] => Env[ValueLike] = { input => + Env.onLast2(input.head, input.tail.head) { (arg0, arg1) => + // 0 if arg0 < arg1 else ( + // 1 if arg0 == arg1 else 2 + // ) + Code + .Ternary( + Code.fromInt(0), + Code.Op(arg0, Code.Const.Lt, arg1), Code.Ternary( - Code.fromInt(0), - Code.Op(arg0, Code.Const.Lt, arg1), - Code.Ternary( - Code.fromInt(1), - Code.Op(arg0, Code.Const.Eq, arg1), - Code.fromInt(2))).simplify - } + Code.fromInt(1), + Code.Op(arg0, Code.Const.Eq, arg1), + Code.fromInt(2) + ) + ) + .simplify + } } val results: Map[Bindable, (List[ValueLike] => Env[ValueLike], Int)] = Map( - (Identifier.unsafeBindable("add"), + ( + Identifier.unsafeBindable("add"), ( - input => Env.onLast2(input.head, input.tail.head)(_.evalPlus(_)) - , 2)), - (Identifier.unsafeBindable("sub"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.evalMinus(_)) - } , 2)), - (Identifier.unsafeBindable("times"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.evalTimes(_)) - }, 2)), - (Identifier.unsafeBindable("div"), - ({ - input => Env.onLast2(input.head, input.tail.head) { (a, b) => - Code.Ternary( - Code.Op(a, Code.Const.Div, b), - b, // 0 is false in python - Code.fromInt(0) - ).simplify - } - }, 2)), - (Identifier.unsafeBindable("mod_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head) { (a, b) => - Code.Ternary( - Code.Op(a, Code.Const.Mod, b), - b, // 0 is false in python - a - ).simplify - } - }, 2)), + input => Env.onLast2(input.head, input.tail.head)(_.evalPlus(_)), + 2 + ) + ), + ( + Identifier.unsafeBindable("sub"), + ( + { input => + Env.onLast2(input.head, input.tail.head)(_.evalMinus(_)) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("times"), + ( + { input => + Env.onLast2(input.head, input.tail.head)(_.evalTimes(_)) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("div"), + ( + { input => + Env.onLast2(input.head, input.tail.head) { (a, b) => + Code + .Ternary( + Code.Op(a, Code.Const.Div, b), + b, // 0 is false in python + Code.fromInt(0) + ) + .simplify + } + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("mod_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head) { (a, b) => + Code + .Ternary( + Code.Op(a, Code.Const.Mod, b), + b, // 0 is false in python + a + ) + .simplify + } + }, + 2 + ) + ), (Identifier.unsafeBindable("cmp_Int"), (cmpFn, 2)), - (Identifier.unsafeBindable("eq_Int"), - ({ - input => Env.onLast2(input.head, input.tail.head)(_.eval(Code.Const.Eq, _)) - }, 2)), - - (Identifier.unsafeBindable("gcd_Int"), - ({ - input => - (Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar) - .mapN { (tmpa, tmpb, tmpc) => - Env.onLast2(input.head, input.tail.head) { (a, b) => - Code.block( + ( + Identifier.unsafeBindable("eq_Int"), + ( + { input => + Env.onLast2(input.head, input.tail.head)( + _.eval(Code.Const.Eq, _) + ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("gcd_Int"), + ( + { input => + ( + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar + ).mapN { (tmpa, tmpb, tmpc) => + Env.onLast2(input.head, input.tail.head) { (a, b) => + Code + .block( tmpa := a, tmpb := b, - Code.While(tmpb, + Code.While( + tmpb, Code.block( tmpc := tmpb, // we know b != 0 because we are in the while loop @@ -836,152 +961,209 @@ object PythonGen { ) ) .withValue(tmpa) - } } - .flatten - }, 2)), - //external def int_loop(intValue: Int, state: a, fn: (Int, a) -> TupleCons[Int, TupleCons[a, Unit]]) -> a - // def int_loop(i, a, fn): - // if i <= 0: a - // else: - // (i1, a1) = fn(i, a) - // if i <= i1: a - // else int_loop(i1, a, fn) - // - // def int_loop(i, a, fn): - // cont = (0 < i) - // res = a - // _i = i - // _a = a - // while cont: - // res = fn(_i, _a) - // tmp_i = res[0] - // _a = res[1][0] - // cont = (0 < tmp_i) and (tmp_i < _i) - // _i = tmp_i - // return _a - (Identifier.unsafeBindable("int_loop"), - ({ - input => - (Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar) - .tupled + }.flatten + }, + 2 + ) + ), + // external def int_loop(intValue: Int, state: a, fn: (Int, a) -> TupleCons[Int, TupleCons[a, Unit]]) -> a + // def int_loop(i, a, fn): + // if i <= 0: a + // else: + // (i1, a1) = fn(i, a) + // if i <= i1: a + // else int_loop(i1, a, fn) + // + // def int_loop(i, a, fn): + // cont = (0 < i) + // res = a + // _i = i + // _a = a + // while cont: + // res = fn(_i, _a) + // tmp_i = res[0] + // _a = res[1][0] + // cont = (0 < tmp_i) and (tmp_i < _i) + // _i = tmp_i + // return _a + ( + Identifier.unsafeBindable("int_loop"), + ( + { input => + ( + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar + ).tupled .flatMap { case (cont, res, _i, _a, tmp_i) => Env.onLasts(input) { case i :: a :: fn :: Nil => - Code.block( - cont := Code.Op(Code.fromInt(0), Code.Const.Lt, i), - res := a, - _i := i, - _a := a, - Code.While(cont, - Code.block( - res := fn(_i, _a), - tmp_i := res.get(0), - _a := res.get(1).get(0), - cont := Code.Op(Code.Op(Code.fromInt(0), Code.Const.Lt, tmp_i), - Code.Const.And, - Code.Op(tmp_i, Code.Const.Lt, _i)), - _i := tmp_i + Code + .block( + cont := Code.Op(Code.fromInt(0), Code.Const.Lt, i), + res := a, + _i := i, + _a := a, + Code.While( + cont, + Code.block( + res := fn(_i, _a), + tmp_i := res.get(0), + _a := res.get(1).get(0), + cont := Code.Op( + Code + .Op(Code.fromInt(0), Code.Const.Lt, tmp_i), + Code.Const.And, + Code.Op(tmp_i, Code.Const.Lt, _i) + ), + _i := tmp_i ) - ) - ) - .withValue(_a) + ) + ) + .withValue(_a) case other => // $COVERAGE-OFF$ - throw new IllegalStateException(s"expected arity 3 got: $other") - // $COVERAGE-ON$ + throw new IllegalStateException( + s"expected arity 3 got: $other" + ) + // $COVERAGE-ON$ } } - }, 3)), - (Identifier.unsafeBindable("concat_String"), - ({ input => - Env.onLastM(input.head) { listOfStrings => - // convert to python list, then call "".join(seq) - Env.newAssignableVar - .flatMap { pyList => - bosatsuListToPython(pyList, listOfStrings) - .map { loop => - Code.block( - pyList := Code.MakeList(Nil), - loop - ) - .withValue { - Code.PyString("").dot(Code.Ident("join"))(pyList) - } + }, + 3 + ) + ), + ( + Identifier.unsafeBindable("concat_String"), + ( + { input => + Env.onLastM(input.head) { listOfStrings => + // convert to python list, then call "".join(seq) + Env.newAssignableVar + .flatMap { pyList => + bosatsuListToPython(pyList, listOfStrings) + .map { loop => + Code + .block( + pyList := Code.MakeList(Nil), + loop + ) + .withValue { + Code.PyString("").dot(Code.Ident("join"))(pyList) + } + } } } - } - }, 1)), - (Identifier.unsafeBindable("int_to_String"), - ({ - input => Env.onLast(input.head) { - case Code.PyInt(i) => Code.PyString(i.toString) - case i => Code.Apply(Code.DotSelect(i, Code.Ident("__str__")), Nil) - } - }, 1)), - (Identifier.unsafeBindable("trace"), - ({ - input => Env.onLast2(input.head, input.tail.head) { (msg, i) => - Code.Call(Code.Apply(Code.Ident("print"), msg :: i :: Nil)) - .withValue(i) - } - }, 2)), - (Identifier.unsafeBindable("partition_String"), - ({ - input => - Env.newAssignableVar - .flatMap { res => - Env.onLast2(input.head, input.tail.head) { (str, sep) => - // if sep == "": None - // else: - // (a, s1, b) = str.partition(sep) - // if s1: (1, (a, (b, ()))) - // else: (0, ) - val a = res.get(0) - val s1 = res.get(1) - val b = res.get(2) - val success = Code.MakeTuple(Code.fromInt(1) :: - Code.MakeTuple(a :: Code.MakeTuple(b :: Code.Const.Unit :: Nil) :: Nil) :: - Nil - ) - val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) - val nonEmpty = - (res := str.dot(Code.Ident("partition"))(sep)) - .withValue(Code.Ternary(success, s1, fail)) - - Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + }, + 1 + ) + ), + ( + Identifier.unsafeBindable("int_to_String"), + ( + { input => + Env.onLast(input.head) { + case Code.PyInt(i) => Code.PyString(i.toString) + case i => + Code.Apply(Code.DotSelect(i, Code.Ident("__str__")), Nil) } - } - }, 2)), - (Identifier.unsafeBindable("rpartition_String"), - ({ - input => - Env.newAssignableVar - .flatMap { res => - Env.onLast2(input.head, input.tail.head) { (str, sep) => - // (a, s1, b) = str.partition(sep) - // if s1: (1, (a, (b, ()))) - // else: (0, ) - val a = res.get(0) - val s1 = res.get(1) - val b = res.get(2) - val success = Code.MakeTuple(Code.fromInt(1) :: - Code.MakeTuple(a :: Code.MakeTuple(b :: Code.Const.Unit :: Nil) :: Nil) :: - Nil - ) - val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) - val nonEmpty = - (res := str.dot(Code.Ident("rpartition"))(sep)) - .withValue(Code.Ternary(success, s1, fail)) - - Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + }, + 1 + ) + ), + ( + Identifier.unsafeBindable("trace"), + ( + { input => + Env.onLast2(input.head, input.tail.head) { (msg, i) => + Code + .Call(Code.Apply(Code.Ident("print"), msg :: i :: Nil)) + .withValue(i) } - } - }, 2)), - (Identifier.unsafeBindable("string_Order_fn"), (cmpFn, 2)) - ) + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("partition_String"), + ( + { input => + Env.newAssignableVar + .flatMap { res => + Env.onLast2(input.head, input.tail.head) { (str, sep) => + // if sep == "": None + // else: + // (a, s1, b) = str.partition(sep) + // if s1: (1, (a, (b, ()))) + // else: (0, ) + val a = res.get(0) + val s1 = res.get(1) + val b = res.get(2) + val success = Code.MakeTuple( + Code.fromInt(1) :: + Code.MakeTuple( + a :: Code.MakeTuple( + b :: Code.Const.Unit :: Nil + ) :: Nil + ) :: + Nil + ) + val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) + val nonEmpty = + (res := str.dot(Code.Ident("partition"))(sep)) + .withValue(Code.Ternary(success, s1, fail)) + + Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + } + } + }, + 2 + ) + ), + ( + Identifier.unsafeBindable("rpartition_String"), + ( + { input => + Env.newAssignableVar + .flatMap { res => + Env.onLast2(input.head, input.tail.head) { (str, sep) => + // (a, s1, b) = str.partition(sep) + // if s1: (1, (a, (b, ()))) + // else: (0, ) + val a = res.get(0) + val s1 = res.get(1) + val b = res.get(2) + val success = Code.MakeTuple( + Code.fromInt(1) :: + Code.MakeTuple( + a :: Code.MakeTuple( + b :: Code.Const.Unit :: Nil + ) :: Nil + ) :: + Nil + ) + val fail = Code.MakeTuple(Code.fromInt(0) :: Nil) + val nonEmpty = + (res := str.dot(Code.Ident("rpartition"))(sep)) + .withValue(Code.Ternary(success, s1, fail)) + + Code.IfElse(NonEmptyList.one((sep, nonEmpty)), fail) + } + } + }, + 2 + ) + ), + (Identifier.unsafeBindable("string_Order_fn"), (cmpFn, 2)) + ) - def bosatsuListToPython(pyList: Code.Ident, bList: Expression): Env[Statement] = + def bosatsuListToPython( + pyList: Code.Ident, + bList: Expression + ): Env[Statement] = Env.newAssignableVar .map { tmp => // tmp = bList @@ -991,7 +1173,8 @@ object PythonGen { // tmp = tmp[2] Code.block( tmp := bList, - Code.While(isNonEmpty(tmp), + Code.While( + isNonEmpty(tmp), Code.block( Code.Call(pyList.dot(Code.Ident("append"))(headList(tmp))), tmp := tailList(tmp) @@ -1000,13 +1183,17 @@ object PythonGen { ) } - def unapply(expr: Expr): Option[(List[ValueLike] => Env[ValueLike], Int)] = + def unapply( + expr: Expr + ): Option[(List[ValueLike] => Env[ValueLike], Int)] = expr match { case Global(PackageName.PredefName, name) => results.get(name) - case _ => None + case _ => None } - def makeLambda(arity: Int)(fn: List[ValueLike] => Env[ValueLike]): Env[ValueLike] = + def makeLambda( + arity: Int + )(fn: List[ValueLike] => Env[ValueLike]): Env[ValueLike] = for { vars <- (1 to arity).toList.traverse(_ => Env.newAssignableVar) body <- fn(vars) @@ -1017,7 +1204,10 @@ object PythonGen { } yield res } - class Ops(packName: PackageName, remap: (PackageName, Bindable) => Env[Option[ValueLike]]) { + class Ops( + packName: PackageName, + remap: (PackageName, Bindable) => Env[Option[ValueLike]] + ) { /* * enums with no fields are integers * enums and structs are tuples @@ -1041,9 +1231,9 @@ object PythonGen { Env.onLasts(vExpr :: args)(Code.MakeTuple(_)) } case MakeStruct(arity) => - if (arity == 0) Monad[Env].pure(Code.Const.Unit) - else if (arity == 1) Monad[Env].pure(args.head) - else Env.onLasts(args)(Code.MakeTuple(_)) + if (arity == 0) Monad[Env].pure(Code.Const.Unit) + else if (arity == 1) Monad[Env].pure(args.head) + else Env.onLasts(args)(Code.MakeTuple(_)) case ZeroNat => Monad[Env].pure(Code.Const.Zero) case SuccNat => @@ -1058,8 +1248,7 @@ object PythonGen { // $COVERAGE-OFF$ throw new IllegalStateException(s"invalid arity $sz for $ce") // $COVERAGE-ON$ - } - else { + } else { // this is the case where we are using the constructor like a function assert(args.isEmpty) for { @@ -1079,7 +1268,9 @@ object PythonGen { ix match { case EqualsLit(expr, lit) => val literal = Code.litToExpr(lit) - loop(expr).flatMap(Env.onLast(_) { ex => Code.Op(ex, Code.Const.Eq, literal) }) + loop(expr).flatMap(Env.onLast(_) { ex => + Code.Op(ex, Code.Const.Eq, literal) + }) case EqualsNat(nat, zeroOrSucc) => val natF = loop(nat) @@ -1109,56 +1300,71 @@ object PythonGen { if (useInts) { // this is represented as an integer Code.Op(t, Code.Const.Eq, idxExpr) - } - else + } else Code.Op(t.get(0), Code.Const.Eq, idxExpr) } } case SetMut(LocalAnonMut(mut), expr) => - (Env.nameForAnon(mut), loop(expr)) - .mapN { (ident, result) => - Env.onLast(result) { resx => - (ident := resx).withValue(Code.Const.True) - } + (Env.nameForAnon(mut), loop(expr)).mapN { (ident, result) => + Env.onLast(result) { resx => + (ident := resx).withValue(Code.Const.True) } - .flatten + }.flatten case MatchString(str, pat, binds) => - (loop(str), binds.traverse { case LocalAnonMut(m) => Env.nameForAnon(m) }) - .mapN { (strVL, binds) => - Env.onLastM(strVL)(matchString(_, pat, binds)) - } - .flatten + ( + loop(str), + binds.traverse { case LocalAnonMut(m) => Env.nameForAnon(m) } + ).mapN { (strVL, binds) => + Env.onLastM(strVL)(matchString(_, pat, binds)) + }.flatten case SearchList(locMut, init, check, optLeft) => // check to see if we can find a non-empty // list that matches check - (loop(init), boolExpr(check)) - .mapN { (initVL, checkVL) => - searchList(locMut, initVL, checkVL, optLeft) - } - .flatten + (loop(init), boolExpr(check)).mapN { (initVL, checkVL) => + searchList(locMut, initVL, checkVL, optLeft) + }.flatten } - def matchString(strEx: Expression, pat: List[StrPart], binds: List[Code.Ident]): Env[ValueLike] = { + def matchString( + strEx: Expression, + pat: List[StrPart], + binds: List[Code.Ident] + ): Env[ValueLike] = { import StrPart.{LitStr, Glob} val bindArray = binds.toArray // return a value like expression that contains the boolean result // and assigns all the bindings along the way - def loop(offsetIdent: Code.Ident, pat: List[StrPart], next: Int): Env[ValueLike] = + def loop( + offsetIdent: Code.Ident, + pat: List[StrPart], + next: Int + ): Env[ValueLike] = pat match { case Nil => - //offset == str.length - Monad[Env].pure(Code.Op(offsetIdent, Code.Const.Eq, strEx.dot(Code.Ident("__len__"))())) + // offset == str.length + Monad[Env].pure( + Code.Op( + offsetIdent, + Code.Const.Eq, + strEx.dot(Code.Ident("__len__"))() + ) + ) case LitStr(expect) :: tail => - //val len = expect.length - //str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) + // val len = expect.length + // str.regionMatches(offset, expect, 0, len) && loop(offset + len, tail, next) // // strEx.startswith(expect, offsetIdent) loop(offsetIdent, tail, next) .flatMap { loopRes => - val regionMatches = strEx.dot(Code.Ident("startswith"))(Code.PyString(expect), offsetIdent) + val regionMatches = strEx.dot(Code.Ident("startswith"))( + Code.PyString(expect), + offsetIdent + ) val rest = ( - offsetIdent := (offsetIdent.evalPlus(Code.fromInt(expect.length))) + offsetIdent := (offsetIdent.evalPlus( + Code.fromInt(expect.length) + )) ).withValue(loopRes) Env.andCode(regionMatches, rest) @@ -1170,10 +1376,10 @@ object PythonGen { Monad[Env].pure( if (h.capture) { // b = str[offset:] - (bindArray(next) := Code.SelectRange(strEx, Some(offsetIdent), None)) + (bindArray(next) := Code + .SelectRange(strEx, Some(offsetIdent), None)) .withValue(Code.Const.True) - } - else Code.Const.True + } else Code.Const.True ) case LitStr(expect) :: tail2 => // here we have to make a loop @@ -1213,67 +1419,87 @@ object PythonGen { } } result - */ - (Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar, Env.newAssignableVar) - .mapN { (start, result, candidate, candOffset) => - val searchEnv = loop(candOffset, tail2, next1) - - def onSearch(search: ValueLike): Env[Statement] = - Env.ifElseS(search, - { - // we have matched - val capture = if (h.capture) (bindArray(next) := Code.SelectRange(strEx, Some(offsetIdent), Some(candidate))) else Code.Pass - Code.block( - capture, - result := Code.Const.True, + */ + ( + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar, + Env.newAssignableVar + ).mapN { (start, result, candidate, candOffset) => + val searchEnv = loop(candOffset, tail2, next1) + + def onSearch(search: ValueLike): Env[Statement] = + Env.ifElseS( + search, { + // we have matched + val capture = + if (h.capture) + (bindArray(next) := Code.SelectRange( + strEx, + Some(offsetIdent), + Some(candidate) + )) + else Code.Pass + Code.block( + capture, + result := Code.Const.True, + start := Code.fromInt(-1) + ) + }, { + // we couldn't match at start, advance just after the + // candidate + start := candidate.evalPlus(Code.fromInt(1)) + } + ) + + def findBranch(search: ValueLike): Env[Statement] = + onSearch(search) + .flatMap { onS => + Env.ifElseS( + Code + .Op(candidate, Code.Const.Gt, Code.fromInt(-1)), { + // update candidate and search + Code.block( + candOffset := Code.Op( + candidate, + Code.Const.Plus, + Code.fromInt(expect.length) + ), + onS + ) + }, { + // else no more candidates start := Code.fromInt(-1) - ) - }, - { - // we couldn't match at start, advance just after the - // candidate - start := candidate.evalPlus(Code.fromInt(1)) - }) - - def findBranch(search: ValueLike): Env[Statement] = - onSearch(search) - .flatMap { onS => - Env.ifElseS( - Code.Op(candidate, Code.Const.Gt, Code.fromInt(-1)), - { - // update candidate and search - Code.block( - candOffset := Code.Op(candidate, Code.Const.Plus, Code.fromInt(expect.length)), - onS) - }, - { - // else no more candidates - start := Code.fromInt(-1) - } - ) - } + } + ) + } - for { - search <- searchEnv - find <- findBranch(search) - } yield - (Code.block( - start := offsetIdent, - result := Code.Const.False, - Code.While(Code.Op(start, Code.Const.Gt, Code.fromInt(-1)), - Code.block( - candidate := strEx.dot(Code.Ident("find"))(Code.PyString(expect), start), - find - ) + for { + search <- searchEnv + find <- findBranch(search) + } yield (Code + .block( + start := offsetIdent, + result := Code.Const.False, + Code.While( + Code.Op(start, Code.Const.Gt, Code.fromInt(-1)), + Code.block( + candidate := strEx.dot(Code.Ident("find"))( + Code.PyString(expect), + start + ), + find ) ) - .withValue(result)) - } - .flatten + ) + .withValue(result)) + }.flatten case (_: Glob) :: _ => // $COVERAGE-OFF$ - throw new IllegalArgumentException(s"pattern: $pat should have been prevented: adjacent globs are not permitted (one is always empty)") - // $COVERAGE-ON$ + throw new IllegalArgumentException( + s"pattern: $pat should have been prevented: adjacent globs are not permitted (one is always empty)" + ) + // $COVERAGE-ON$ } } @@ -1284,7 +1510,12 @@ object PythonGen { } } - def searchList(locMut: LocalAnonMut, initVL: ValueLike, checkVL: ValueLike, optLeft: Option[LocalAnonMut]): Env[ValueLike] = { + def searchList( + locMut: LocalAnonMut, + initVL: ValueLike, + checkVL: ValueLike, + optLeft: Option[LocalAnonMut] + ): Env[ValueLike] = { /* * here is the implementation from MatchlessToValue * @@ -1310,35 +1541,41 @@ object PythonGen { } res } - */ - (Env.nameForAnon(locMut.ident), optLeft.traverse { lm => Env.nameForAnon(lm.ident) }, Env.newAssignableVar, Env.newAssignableVar) - .mapN { (currentList, optLeft, res, tmpList) => + */ + ( + Env.nameForAnon(locMut.ident), + optLeft.traverse { lm => Env.nameForAnon(lm.ident) }, + Env.newAssignableVar, + Env.newAssignableVar + ) + .mapN { (currentList, optLeft, res, tmpList) => Code .block( res := Code.Const.False, tmpList := initVL, optLeft.fold(Code.pass)(_ := emptyList), // we don't match empty lists, so if currentList reaches Empty we are done - Code.While(isNonEmpty(tmpList), + Code.While( + isNonEmpty(tmpList), Code.block( - currentList := tmpList, - res := checkVL, - Code.ifStatement( - NonEmptyList( - (res, (tmpList := emptyList)), - Nil), - Some { - Code.block( - tmpList := tailList(tmpList), - optLeft.fold(Code.pass) { left => - left := consList(headList(currentList), left) - } - ) - }) + currentList := tmpList, + res := checkVL, + Code.ifStatement( + NonEmptyList((res, (tmpList := emptyList)), Nil), + Some { + Code.block( + tmpList := tailList(tmpList), + optLeft.fold(Code.pass) { left => + left := consList(headList(currentList), left) + } + ) + } ) ) - ).withValue(res) - } + ) + ) + .withValue(res) + } } def topLet(name: Code.Ident, expr: Expr, v: ValueLike): Env[Statement] = @@ -1385,18 +1622,19 @@ object PythonGen { case Lambda(_, args, res) => // python closures work the same so we don't // need to worry about what we capture - (args.traverse(Env.bind(_)), loop(res)).mapN { (args, res) => - res match { - case x: Expression => - Monad[Env].pure(Code.Lambda(args.toList, x)) - case v => - for { - defName <- Env.newAssignableVar - defn = Env.makeDef(defName, args, v) - } yield defn.withValue(defName) + (args.traverse(Env.bind(_)), loop(res)) + .mapN { (args, res) => + res match { + case x: Expression => + Monad[Env].pure(Code.Lambda(args.toList, x)) + case v => + for { + defName <- Env.newAssignableVar + defn = Env.makeDef(defName, args, v) + } yield defn.withValue(defName) + } } - } - .flatMap(_ <* args.traverse_(Env.unbind(_))) + .flatMap(_ <* args.traverse_(Env.unbind(_))) case LoopFn(_, thisName, args, body) => // note, thisName is already bound because LoopFn // is a lambda, not a def @@ -1435,33 +1673,32 @@ object PythonGen { if (p == packName) { // This is just a name in the local package Env.topLevelName(n) - } - else { - (Env.importPackage(p), Env.topLevelName(n)).mapN(Code.DotSelect(_, _)) + } else { + (Env.importPackage(p), Env.topLevelName(n)) + .mapN(Code.DotSelect(_, _)) } } - case Local(b) => Env.deref(b) - case LocalAnon(a) => Env.nameForAnon(a) + case Local(b) => Env.deref(b) + case LocalAnon(a) => Env.nameForAnon(a) case LocalAnonMut(m) => Env.nameForAnon(m) case App(PredefExternal((fn, _)), args) => - args - .toList + args.toList .traverse(loop) .flatMap(fn) case App(cons: ConsExpr, args) => args.traverse(loop).flatMap { pxs => makeCons(cons, pxs.toList) } case App(expr, args) => - (loop(expr), args.traverse(loop)) - .mapN { (fn, args) => - Env.onLasts(fn :: args.toList) { - case fn :: args => Code.Apply(fn, args) - case other => - // $COVERAGE-OFF$ - throw new IllegalStateException(s"got $other, expected to match $expr") - // $COVERAGE-ON$ - } + (loop(expr), args.traverse(loop)).mapN { (fn, args) => + Env.onLasts(fn :: args.toList) { + case fn :: args => Code.Apply(fn, args) + case other => + // $COVERAGE-OFF$ + throw new IllegalStateException( + s"got $other, expected to match $expr" + ) + // $COVERAGE-ON$ } - .flatten + }.flatten case Let(localOrBind, value, in) => val inF = loop(in) @@ -1477,8 +1714,7 @@ object PythonGen { wv = tl.withValue(ine) _ <- Env.unbind(b) } yield wv - } - else { + } else { // value b is in scope after ve for { ve <- loop(value) @@ -1491,12 +1727,10 @@ object PythonGen { } case Left(LocalAnon(l)) => // anonymous names never shadow - (Env.nameForAnon(l), loop(value)) - .mapN { (bi, vE) => - (topLet(bi, value, vE), inF) - .mapN(_.withValue(_)) - } - .flatten + (Env.nameForAnon(l), loop(value)).mapN { (bi, vE) => + (topLet(bi, value, vE), inF) + .mapN(_.withValue(_)) + }.flatten } case LetMut(LocalAnonMut(_), in) => @@ -1520,11 +1754,9 @@ object PythonGen { (boolExpr(c), loop(t)).tupled } - (ifsV, loop(last)) - .mapN { (ifs, elseV) => - Env.ifElse(ifs, elseV) - } - .flatten + (ifsV, loop(last)).mapN { (ifs, elseV) => + Env.ifElse(ifs, elseV) + }.flatten case Always(cond, expr) => (boolExpr(cond).map(Code.always), loop(expr)) @@ -1542,8 +1774,7 @@ object PythonGen { if (sz == 1) { // we don't bother to wrap single item structs exprR - } - else { + } else { // structs are just tuples exprR.flatMap { tup => Env.onLast(tup)(_.get(idx)) diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala index 1eecae403..9afd19c83 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Dag.scala @@ -29,8 +29,7 @@ sealed trait Dag[A] { def toToposorted: Toposort.Success[A] = { val layerMap: Map[Int, SortedSet[A]] = nodes.groupBy(layerOf(_)) - val ls = (0 until layers) - .iterator + val ls = (0 until layers).iterator .map { idx => // by construction all layers have at least 1 item NonEmptyList.fromListUnsafe(layerMap(idx).toList) diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala index e4cb03aaa..f221e6382 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Memoize.scala @@ -6,11 +6,13 @@ import scala.collection.immutable.SortedMap object Memoize { - /** - * This memoizes using a sorted map (not a hashMap) in a non-threadsafe manner - * returning None, means we cannot compute this function because it loops forever - */ - def memoizeSorted[A: Ordering, B](fn: (A, A => Option[B]) => Option[B]): A => Option[B] = { + /** This memoizes using a sorted map (not a hashMap) in a non-threadsafe + * manner returning None, means we cannot compute this function because it + * loops forever + */ + def memoizeSorted[A: Ordering, B]( + fn: (A, A => Option[B]) => Option[B] + ): A => Option[B] = { var cache = SortedMap.empty[A, Option[B]] new Function[A, Option[B]] { self => @@ -29,10 +31,9 @@ object Memoize { } } - /** - * This memoizes using a hash map in a non-threadsafe manner - * this throws if you don't have a dag - */ + /** This memoizes using a hash map in a non-threadsafe manner this throws if + * you don't have a dag + */ def memoizeDagHashed[A, B](fn: (A, A => B) => B): A => B = { var cache = Map.empty[A, Option[B]] @@ -48,15 +49,14 @@ object Memoize { cache = cache.updated(a, Some(b)) b case Some(Some(b)) => b - case Some(None) => sys.error(s"loop found evaluating $a") + case Some(None) => sys.error(s"loop found evaluating $a") } } } - /** - * This memoizes using a hash map in a threadsafe manner - * it may loop forever and stack overflow if you don't have a DAG - */ + /** This memoizes using a hash map in a threadsafe manner it may loop forever + * and stack overflow if you don't have a DAG + */ def memoizeDagHashedConcurrent[A, B](fn: (A, A => B) => B): A => B = { val cache: ConcurrentHashMap[A, B] = new ConcurrentHashMap[A, B]() @@ -77,12 +77,14 @@ object Memoize { } } - /** - * This memoizes using a hash map in a threadsafe manner - * if the dependencies do not form a dag, you will deadlock - */ - def memoizeDagFuture[A, B](fn: (A, A => Par.F[B]) => Par.F[B]): A => Par.F[B] = { - val cache: ConcurrentHashMap[A, Par.P[B]] = new ConcurrentHashMap[A, Par.P[B]]() + /** This memoizes using a hash map in a threadsafe manner if the dependencies + * do not form a dag, you will deadlock + */ + def memoizeDagFuture[A, B]( + fn: (A, A => Par.F[B]) => Par.F[B] + ): A => Par.F[B] = { + val cache: ConcurrentHashMap[A, Par.P[B]] = + new ConcurrentHashMap[A, Par.P[B]]() new Function[A, Par.F[B]] { self => def apply(a: A) = { @@ -93,8 +95,7 @@ object Memoize { val resFut = fn(a, self) Par.complete(prom, resFut) resFut - } - else { + } else { // someone else is already working: Par.toF(prevProm) } diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala index 13ef8bedf..5fdebfadf 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Paths.scala @@ -3,29 +3,32 @@ package org.bykn.bosatsu.graph import cats.data.NonEmptyList object Paths { - /** - * A list of cycles all terminating at node - * E is intended to carry state about the edge in the graph - */ - def allCycles[A, E](node: A)(nfn: A => List[(E, A)]): List[NonEmptyList[(E, A)]] = + + /** A list of cycles all terminating at node E is intended to carry state + * about the edge in the graph + */ + def allCycles[A, E](node: A)( + nfn: A => List[(E, A)] + ): List[NonEmptyList[(E, A)]] = allPaths(node, node)(nfn) - /** - * A list of paths all terminating at to, but omitting from. - * E is intended to carry state about the edge in the graph - */ - def allPaths[A, E](from: A, to: A)(nfn: A => List[(E, A)]): List[NonEmptyList[(E, A)]] = { + /** A list of paths all terminating at to, but omitting from. E is intended to + * carry state about the edge in the graph + */ + def allPaths[A, E](from: A, to: A)( + nfn: A => List[(E, A)] + ): List[NonEmptyList[(E, A)]] = { def loop(from: A, to: A, avoid: Set[A]): List[NonEmptyList[(E, A)]] = { val newPaths = nfn(from).filterNot { case (_, a) => avoid(a) } val (ends, notEnds) = newPaths.partition { case (_, a) => a == to } - val rest = notEnds.flatMap { case edge@(_, a) => + val rest = notEnds.flatMap { case edge @ (_, a) => // don't loop back on a, loops to a are handled by ends loop(a, to, avoid + a).map(edge :: _) } NonEmptyList.fromList(ends) match { - case None => rest + case None => rest case Some(endsNE) => endsNE :: rest } } @@ -33,15 +36,13 @@ object Paths { loop(from, to, Set.empty) } - /** - * Same as allPaths but without the edge annotation type - */ + /** Same as allPaths but without the edge annotation type + */ def allPaths0[A](start: A, end: A)(nfn: A => List[A]): List[NonEmptyList[A]] = allPaths(start, end)(nfn.andThen(_.map(((), _)))).map(_.map(_._2)) - /** - * Same as allCycles but without the edge annotation type - */ + /** Same as allCycles but without the edge annotation type + */ def allCycle0[A](start: A)(nfn: A => List[A]): List[NonEmptyList[A]] = allPaths0(start, start)(nfn) } diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala index ce091d008..40ab968dc 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Toposort.scala @@ -5,10 +5,9 @@ import cats.syntax.all._ object Toposort { - /** - * A result is the subdag in layers, - * as well as a set of loopNodes (a sorted list of nodes that don't form a dag) - */ + /** A result is the subdag in layers, as well as a set of loopNodes (a sorted + * list of nodes that don't form a dag) + */ sealed abstract class Result[A] { // these are the nodes which depend on a cyclic subgraph def loopNodes: List[A] @@ -18,7 +17,7 @@ object Toposort { def toSuccess: Option[Vector[NonEmptyList[A]]] = this match { case Success(res, _) => Some(res) - case Failure(_, _) => None + case Failure(_, _) => None } // true if each layer has exactly one item in it @@ -33,18 +32,23 @@ object Toposort { def isFailure: Boolean = !isSuccess } - final case class Success[A](layers: Vector[NonEmptyList[A]], nfn: A => List[A]) extends Result[A] { + final case class Success[A]( + layers: Vector[NonEmptyList[A]], + nfn: A => List[A] + ) extends Result[A] { def loopNodes = Nil } - final case class Failure[A](loopNodes: List[A], layers: Vector[NonEmptyList[A]]) extends Result[A] + final case class Failure[A]( + loopNodes: List[A], + layers: Vector[NonEmptyList[A]] + ) extends Result[A] - /** - * Build a deterministic topological sort - * of a graph. The items in the position i depend only - * on things at position i-1 or less. - * - * return a result which tells us the layers of the dag, and the non-dag nodes - */ + /** Build a deterministic topological sort of a graph. The items in the + * position i depend only on things at position i-1 or less. + * + * return a result which tells us the layers of the dag, and the non-dag + * nodes + */ def sort[A: Ordering](n: Iterable[A])(fn: A => List[A]): Result[A] = if (n.isEmpty) Success(Vector.empty, fn) else { @@ -57,13 +61,12 @@ object Toposort { nonEmpty.traverse(rec).map(_.max + 1) } } - val res = n - .toList + val res = n.toList // go through in a deterministic order .sorted .map { n => depth(n) match { - case None => Left(n) + case None => Left(n) case Some(d) => Right((d, n)) } } @@ -77,13 +80,12 @@ object Toposort { // we have to be bad if we aren't good bad = true Vector.empty - } - else { + } else { val len = goodIt.max + 1 val ary = Array.fill(len)(List.newBuilder[A]) res.foreach { case Right((idx, a)) => ary(idx) += a - case Left(_) => bad = true + case Left(_) => bad = true } // the items are already sorted since we added them in sorted order diff --git a/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala b/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala index de318341f..30e2cab6f 100644 --- a/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala +++ b/core/src/main/scala/org/bykn/bosatsu/graph/Tree.scala @@ -13,25 +13,31 @@ object Tree { val mapToTree: Map[A, Tree[A]] = toMap(t) - { (a: A) => mapToTree.get(a).fold(List.empty[A])(_.children.map(_.item)) } } - /** - * either return a tree representation of this dag or all cycles - * - * Note, this could run in a monadic context if we needed that: - * nfn: A => F[List[A]] for some monad F[_] - */ - def dagToTree[A](node: A)(nfn: A => List[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = { - def treeOf(path: NonEmptyList[A], visited: Set[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = { + /** either return a tree representation of this dag or all cycles + * + * Note, this could run in a monadic context if we needed that: nfn: A => + * F[List[A]] for some monad F[_] + */ + def dagToTree[A]( + node: A + )(nfn: A => List[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = { + def treeOf( + path: NonEmptyList[A], + visited: Set[A] + ): ValidatedNel[NonEmptyList[A], Tree[A]] = { val children = nfn(path.head) - def assumeValid(children: List[A]): ValidatedNel[NonEmptyList[A], Tree[A]] = - children.traverse { a => - // we grow the path out here - treeOf(a :: path, visited + a) - } - .map(Tree(path.head, _)) + def assumeValid( + children: List[A] + ): ValidatedNel[NonEmptyList[A], Tree[A]] = + children + .traverse { a => + // we grow the path out here + treeOf(a :: path, visited + a) + } + .map(Tree(path.head, _)) NonEmptyList.fromList(children.filter(visited)) match { case Some(loops) => @@ -66,8 +72,7 @@ object Tree { def distinctBy[A, B](nel: List[A])(fn: A => B): List[A] = NonEmptyList.fromList(nel) match { - case None => Nil + case None => Nil case Some(nel) => distinctBy(nel)(fn).toList } } - diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala index 3ee40af2a..5286cac16 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/Matcher.scala @@ -13,7 +13,8 @@ trait Matcher[-P, -S, +R] { self => } object Matcher { - implicit class InvariantMatcher[P, S, R](val self: Matcher[P, S, R]) extends AnyVal { + implicit class InvariantMatcher[P, S, R](val self: Matcher[P, S, R]) + extends AnyVal { def mapWithInput[R1](fn: (S, R) => R1): Matcher[P, S, R1] = new Matcher[P, S, R1] { def apply(p: P): S => Option[R1] = { @@ -21,7 +22,7 @@ object Matcher { { (s: S) => next(s) match { - case None => None + case None => None case Some(r) => Some(fn(s, r)) } } @@ -33,15 +34,17 @@ object Matcher { def eqMatcher[A](implicit eqA: Eq[A]): Matcher[A, A, Unit] = new Matcher[A, A, Unit] { - def apply(a: A): A => Option[Unit] = - { (s: A) => if (eqA.eqv(a, s)) someUnit else None } + def apply(a: A): A => Option[Unit] = { (s: A) => + if (eqA.eqv(a, s)) someUnit else None + } } - val charMatcher: Matcher[Char, Char, Unit] = eqMatcher(Eq.fromUniversalEquals[Char]) + val charMatcher: Matcher[Char, Char, Unit] = eqMatcher( + Eq.fromUniversalEquals[Char] + ) def fnMatch[A]: Matcher[A => Boolean, A, Unit] = new Matcher[A => Boolean, A, Unit] { - def apply(p: A => Boolean) = - { (a: A) => if (p(a)) someUnit else None } + def apply(p: A => Boolean) = { (a: A) => if (p(a)) someUnit else None } } } diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala index c58448fde..8ca9f4d64 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/NamedSeqPattern.scala @@ -10,7 +10,7 @@ sealed trait NamedSeqPattern[+A] { def loop(n: NamedSeqPattern[A], right: List[SeqPart[A]]): List[SeqPart[A]] = n match { case Bind(_, n) => loop(n, right) - case NEmpty => right + case NEmpty => right case NCat(first, second) => val r2 = loop(second, right) loop(first, r2) @@ -29,10 +29,10 @@ sealed trait NamedSeqPattern[+A] { // we are renderable if all Wild/AnyElem are named def isRenderable: Boolean = this match { - case NEmpty => true - case Bind(_, _) => true + case NEmpty => true + case Bind(_, _) => true case NSeqPart(Lit(_)) => true - case NSeqPart(_) => false + case NSeqPart(_) => false case NCat(l, r) => l.isRenderable && r.isRenderable } @@ -42,14 +42,15 @@ sealed trait NamedSeqPattern[+A] { def loop(n: NamedSeqPattern[A], right: S): Option[S] = n match { - case NEmpty => Some(right) + case NEmpty => Some(right) case Bind(nm, r) => // since we have this name, we don't need to recurse - names.get(nm) + names + .get(nm) .map { seq => ms.combine(seq, right) } .orElse(loop(r, right)) case NSeqPart(SeqPart.Lit(c)) => Some(ms.combine(fn(c), right)) - case NSeqPart(_) => None + case NSeqPart(_) => None case NCat(l, r) => loop(r, right) .flatMap { right => @@ -62,9 +63,9 @@ sealed trait NamedSeqPattern[+A] { def names: List[String] = this match { - case Bind(name, nsp) => name :: nsp.names + case Bind(name, nsp) => name :: nsp.names case NEmpty | NSeqPart(_) => Nil - case NCat(h, t) => h.names ::: t.names + case NCat(h, t) => h.names ::: t.names } } @@ -77,15 +78,19 @@ object NamedSeqPattern { val Wild: NamedSeqPattern[Nothing] = NSeqPart(SeqPart.Wildcard) val Any: NamedSeqPattern[Nothing] = NSeqPart(SeqPart.AnyElem) - case class Bind[A](name: String, p: NamedSeqPattern[A]) extends NamedSeqPattern[A] + case class Bind[A](name: String, p: NamedSeqPattern[A]) + extends NamedSeqPattern[A] case object NEmpty extends NamedSeqPattern[Nothing] case class NSeqPart[A](part: SeqPart[A]) extends NamedSeqPattern[A] - case class NCat[A](first: NamedSeqPattern[A], second: NamedSeqPattern[A]) extends NamedSeqPattern[A] + case class NCat[A](first: NamedSeqPattern[A], second: NamedSeqPattern[A]) + extends NamedSeqPattern[A] def fromLit[A](a: A): NamedSeqPattern[A] = NSeqPart(SeqPart.Lit(a)) - def matcher[E, I, S, R](split: Splitter[E, I, S, R]): Matcher[NamedSeqPattern[E], S, (R, Map[String, S])] = + def matcher[E, I, S, R]( + split: Splitter[E, I, S, R] + ): Matcher[NamedSeqPattern[E], S, (R, Map[String, S])] = new Matcher[NamedSeqPattern[E], S, (R, Map[String, S])] { def apply(nsp: NamedSeqPattern[E]): S => Option[(R, Map[String, S])] = { val machine = Impl.toMachine(nsp, Nil) @@ -95,7 +100,10 @@ object NamedSeqPattern { } private[this] object Impl { - def toMachine[A](n: NamedSeqPattern[A], right: List[Machine[A]]): List[Machine[A]] = + def toMachine[A]( + n: NamedSeqPattern[A], + right: List[Machine[A]] + ): List[Machine[A]] = n match { case NEmpty => right case Bind(name, n) => @@ -112,31 +120,32 @@ object NamedSeqPattern { def hasWildLeft(m: List[Machine[Any]]): Boolean = m match { - case Nil => false + case Nil => false case MSeqPart(SeqPart.Wildcard) :: _ => true - case MSeqPart(_) :: _ => false - case _ :: tail => hasWildLeft(tail) + case MSeqPart(_) :: _ => false + case _ :: tail => hasWildLeft(tail) } import SeqPart.{AnyElem, Lit, SeqPart1, Wildcard} - def capture[S](empty: S, capturing: List[String], res: Map[String, S])(fn: S => S): Map[String, S] = + def capture[S](empty: S, capturing: List[String], res: Map[String, S])( + fn: S => S + ): Map[String, S] = capturing.foldLeft(res) { (mapB, n) => val right = mapB.get(n) match { - case None => empty + case None => empty case Some(bv) => bv } mapB.updated(n, fn(right)) } def matches[E, I, S, R]( - split: Splitter[E, I, S, R], - m: List[Machine[E]], - capturing: List[String]): S => Option[(R, Map[String, S])] = - + split: Splitter[E, I, S, R], + m: List[Machine[E]], + capturing: List[String] + ): S => Option[(R, Map[String, S])] = m match { case Nil => - val res = Some((split.monoidResult.empty, Map.empty[String, S])) { (str: S) => @@ -150,7 +159,7 @@ object NamedSeqPattern { case Nil => // $COVERAGE-OFF$ sys.error("illegal End with no capturing") - // $COVERAGE-ON$ + // $COVERAGE-ON$ case n :: cap => // if n captured nothing, we need // to add an empty list @@ -169,13 +178,11 @@ object NamedSeqPattern { if (hasWildLeft(tail)) { // two adjacent wilds means this one matches nothing matches(split, tail, capturing) - } - else { + } else { val me = matchEnd(split, tail, capturing) me.andThen { stream => - stream - .headOption + stream.headOption .map { case (prefix, (rightR, rightBind)) => // now merge the prefix result val resMatched = capturing.foldLeft(rightBind) { (st, n) => @@ -195,13 +202,11 @@ object NamedSeqPattern { val headm: I => Option[R] = p1 match { case AnyElem => { (_: I) => someEmpty } - case Lit(c) => split.matcher(c) + case Lit(c) => split.matcher(c) } val tailm: S => Option[(R, Map[String, S])] = - matches(split, - tail, - capturing) + matches(split, tail, capturing) { (str: S) => for { @@ -210,14 +215,18 @@ object NamedSeqPattern { rh <- headm(h) rt <- tailm(t) (tailr, tailm) = rt - } yield (split.monoidResult.combine(rh, tailr), capture(split.emptySeq, capturing, tailm)(split.cons(h, _))) + } yield ( + split.monoidResult.combine(rh, tailr), + capture(split.emptySeq, capturing, tailm)(split.cons(h, _)) + ) } } def matchEnd[E, I, S, R]( - split: Splitter[E, I, S, R], - m: List[Machine[E]], - capturing: List[String]): S => LazyList[(S, (R, Map[String, S]))] = + split: Splitter[E, I, S, R], + m: List[Machine[E]], + capturing: List[String] + ): S => LazyList[(S, (R, Map[String, S]))] = m match { case Nil => // we always match the end @@ -233,7 +242,7 @@ object NamedSeqPattern { case Nil => // $COVERAGE-OFF$ sys.error("illegal End with no capturing") - // $COVERAGE-ON$ + // $COVERAGE-ON$ case n :: cap => // if n captured nothing, we need // to add an empty list @@ -254,24 +263,24 @@ object NamedSeqPattern { val mtail = matches(split, tail, capturing) val splits = p1 match { - case Lit(c) => split.positions(c) + case Lit(c) => split.positions(c) case AnyElem => split.anySplits(_: S) } - { (s: S) => - splits(s).map { case (pre, i, r, post) => - mtail(post) - .map { case (rp, mapRes) => - val res1 = split.monoidResult.combine(r, rp) - val res2 = capture(split.emptySeq, capturing, mapRes)(split.cons(i, _)) - (pre, (res1, res2)) - } - } - .collect { case Some(res) => res } + splits(s) + .map { case (pre, i, r, post) => + mtail(post) + .map { case (rp, mapRes) => + val res1 = split.monoidResult.combine(r, rp) + val res2 = capture(split.emptySeq, capturing, mapRes)( + split.cons(i, _) + ) + (pre, (res1, res2)) + } + } + .collect { case Some(res) => res } } } } } - - diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala index 85a0eea70..e20e8f59a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPart.scala @@ -9,17 +9,19 @@ object SeqPart { override def notWild: Boolean = true } - implicit def partOrdering[E](implicit elemOrdering: Ordering[E]): Ordering[SeqPart[E]] = + implicit def partOrdering[E](implicit + elemOrdering: Ordering[E] + ): Ordering[SeqPart[E]] = new Ordering[SeqPart[E]] { def compare(a: SeqPart[E], b: SeqPart[E]) = (a, b) match { case (Lit(i1), Lit(i2)) => elemOrdering.compare(i1, i2) - case (Lit(_), _) => -1 - case (_, Lit(_)) => 1 - case (AnyElem, AnyElem) => 0 - case (AnyElem, Wildcard) => -1 - case (Wildcard, AnyElem) => 1 + case (Lit(_), _) => -1 + case (_, Lit(_)) => 1 + case (AnyElem, AnyElem) => 0 + case (AnyElem, Wildcard) => -1 + case (Wildcard, AnyElem) => 1 case (Wildcard, Wildcard) => 0 } } @@ -29,7 +31,9 @@ object SeqPart { // 0 or more characters case object Wildcard extends SeqPart[Nothing] - implicit def part1SetOps[A](implicit setOpsA: SetOps[A]): SetOps[SeqPart1[A]] = + implicit def part1SetOps[A](implicit + setOpsA: SetOps[A] + ): SetOps[SeqPart1[A]] = new SetOps[SeqPart1[A]] { private val anyList = AnyElem :: Nil @@ -40,7 +44,7 @@ object SeqPart { def anyDiff(a: A) = setOpsA.top match { - case None => anyList + case None => anyList case Some(topA) => setOpsA.difference(topA, a).map(toPart1) } @@ -48,13 +52,14 @@ object SeqPart { def isTop(c: SeqPart1[A]) = c match { case AnyElem => true - case Lit(a) => setOpsA.isTop(a) + case Lit(a) => setOpsA.isTop(a) } def intersection(p1: SeqPart1[A], p2: SeqPart1[A]): List[SeqPart1[A]] = (p1, p2) match { case (Lit(c1), Lit(c2)) => - setOpsA.intersection(c1, c2) + setOpsA + .intersection(c1, c2) .map(toPart1(_)) case (AnyElem, _) => if (isTop(p2)) AnyElem :: Nil @@ -92,18 +97,15 @@ object SeqPart { def litOpt(u: List[SeqPart1[A]], acc: List[A]): Option[List[Lit[A]]] = u match { case Nil => Some(setOpsA.unifyUnion(acc.reverse).map(Lit(_))) - case AnyElem :: _ => None + case AnyElem :: _ => None case Lit(a) :: _ if setOpsA.isTop(a) => None - case Lit(a) :: tail => litOpt(tail, a :: acc) + case Lit(a) :: tail => litOpt(tail, a :: acc) } - litOpt(u, Nil) match { - case None => AnyElem :: Nil + case None => AnyElem :: Nil case Some(u) => u } } } } - - diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala index d9a827b3f..ab566305a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/SeqPattern.scala @@ -8,23 +8,22 @@ sealed trait SeqPattern[+A] { def matchesAny: Boolean = this match { - case Empty => false + case Empty => false case Cat(Wildcard, t) => t.matchesEmpty - case Cat(_, _) => false + case Cat(_, _) => false } def matchesEmpty: Boolean = this match { - case Empty => true + case Empty => true case Cat(Wildcard, t) => t.matchesEmpty - case Cat(_, _) => false + case Cat(_, _) => false } def isEmpty: Boolean = this == Empty - /** - * Concat that SeqPattern on the right - */ + /** Concat that SeqPattern on the right + */ def +[A1 >: A](that: SeqPattern[A1]): SeqPattern[A1] = SeqPattern.fromList(toList ::: that.toList) @@ -33,14 +32,14 @@ sealed trait SeqPattern[+A] { def prependWild: SeqPattern[A] = this match { - case Cat(AnyElem, t) => Cat(AnyElem, Cat(Wildcard, t)) + case Cat(AnyElem, t) => Cat(AnyElem, Cat(Wildcard, t)) case Cat(Wildcard, _) => this - case notAlreadyWild => Cat(Wildcard, notAlreadyWild) + case notAlreadyWild => Cat(Wildcard, notAlreadyWild) } def toList: List[SeqPart[A]] = this match { - case Empty => Nil + case Empty => Nil case Cat(h, t) => h :: t.toList } @@ -53,20 +52,19 @@ sealed trait SeqPattern[+A] { case Cat(_, _) => None } - /** - * If two wilds are adjacent, the left one will always match empty string - * this normalize just removes the left wild - * - * combine adjacent strings - */ + /** If two wilds are adjacent, the left one will always match empty string + * this normalize just removes the left wild + * + * combine adjacent strings + */ def normalize: SeqPattern[A] = this match { - case Empty => Empty + case Empty => Empty case Cat(Wildcard, Cat(AnyElem, t)) => // move AnyElem out val wtn = Cat(Wildcard, t).normalize Cat(AnyElem, wtn) - case Cat(Wildcard, tail@Cat(Wildcard, _)) => + case Cat(Wildcard, tail @ Cat(Wildcard, _)) => // remove duplicate Wildcard tail.normalize case Cat(h, tail) => @@ -74,28 +72,26 @@ sealed trait SeqPattern[+A] { } def show: String = - toList - .iterator - .map { - case Lit('.') => "\\." - case Lit('*') => "\\*" - case Lit(c) => c.toString - case AnyElem => "." - case Wildcard => "*" - } - .mkString + toList.iterator.map { + case Lit('.') => "\\." + case Lit('*') => "\\*" + case Lit(c) => c.toString + case AnyElem => "." + case Wildcard => "*" + }.mkString } object SeqPattern { case object Empty extends SeqPattern[Nothing] - case class Cat[A](head: SeqPart[A], tail: SeqPattern[A]) extends SeqPattern[A] { + case class Cat[A](head: SeqPart[A], tail: SeqPattern[A]) + extends SeqPattern[A] { // return the last non-empty @annotation.tailrec final def rightMost: SeqPart[A] = tail match { - case Empty => head - case Cat(h, Empty) => h - case Cat(_, r@Cat(_, _)) => r.rightMost + case Empty => head + case Cat(h, Empty) => h + case Cat(_, r @ Cat(_, _)) => r.rightMost } def reverseCat: Cat[A] = { @@ -119,7 +115,7 @@ object SeqPattern { val ordSeqPart: Ordering[SeqPart[A]] = implicitly[Ordering[SeqPart[A]]] def compare(a: SeqPattern[A], b: SeqPattern[A]) = (a, b) match { - case (Empty, Empty) => 0 + case (Empty, Empty) => 0 case (Empty, Cat(_, _)) => -1 case (Cat(_, _), Empty) => 1 case (Cat(h1, t1), Cat(h2, t2)) => @@ -129,7 +125,10 @@ object SeqPattern { } } - implicit def seqPatternSetOps[A](implicit part1SetOps: SetOps[SeqPart.SeqPart1[A]], ordA: Ordering[A]): SetOps[SeqPattern[A]] = + implicit def seqPatternSetOps[A](implicit + part1SetOps: SetOps[SeqPart.SeqPart1[A]], + ordA: Ordering[A] + ): SetOps[SeqPattern[A]] = new SetOps[SeqPattern[A]] { import SeqPart.{SeqPart1, AnyElem, Wildcard} @@ -145,18 +144,20 @@ object SeqPattern { // this is an incomplete heuristic now, not a complete solution def unifyUnion(union: List[SeqPattern[A]]): List[SeqPattern[A]] = unifyUnionList { - union - .map(_.normalize) - .distinct - .map(_.toList) - } + union + .map(_.normalize) + .distinct + .map(_.toList) + } .map(SeqPattern.fromList(_).normalize) .sorted private[this] val someWild = Some(Wildcard :: Nil) private[this] val someNil = Some(Nil) - private def unifyUnionList(union: List[List[SeqPart[A]]]): List[List[SeqPart[A]]] = { + private def unifyUnionList( + union: List[List[SeqPart[A]]] + ): List[List[SeqPart[A]]] = { // if a part of Sequences are the same except this part, can we merge by appending // something? @@ -164,36 +165,35 @@ object SeqPattern { list match { case (a: SeqPart1[A]) :: Wildcard :: Nil if isAny(a) => someWild case Wildcard :: (a: SeqPart1[A]) :: Nil if isAny(a) => someWild - case Wildcard :: Wildcard :: Nil => someWild - case Wildcard :: Nil => someWild - case Nil => someNil - case _ => None + case Wildcard :: Wildcard :: Nil => someWild + case Wildcard :: Nil => someWild + case Nil => someNil + case _ => None } - def unifyPair(left: List[SeqPart[A]], right: List[SeqPart[A]]): Option[List[SeqPart[A]]] = { + def unifyPair( + left: List[SeqPart[A]], + right: List[SeqPart[A]] + ): Option[List[SeqPart[A]]] = { def o1 = if (left.startsWith(right)) { unifySeqPart(left.drop(right.size)).map(right ::: _) - } - else None + } else None def o2 = if (right.startsWith(left)) { unifySeqPart(right.drop(left.size)).map(left ::: _) - } - else None + } else None def o3 = if (left.endsWith(right)) { unifySeqPart(left.dropRight(right.size)).map(_ ::: right) - } - else None + } else None def o4 = if (right.endsWith(left)) { unifySeqPart(right.dropRight(left.size)).map(_ ::: left) - } - else None + } else None def o5 = if (subsetList(left, right)) Some(right) else None def o6 = if (subsetList(right, left)) Some(left) else None @@ -222,13 +222,11 @@ object SeqPattern { val rest = items.iterator.filterNot(_ == null).toList // let's look again unifyUnionList(pair :: rest) - } - else union + } else union } - /** - * return true if p1 <= p2, can give false negatives - */ + /** return true if p1 <= p2, can give false negatives + */ def subset(p1: SeqPattern[A], p2: SeqPattern[A]): Boolean = p2.matchesAny || { // if p2 doesn't matchEmpty but p1 does, we are done @@ -241,9 +239,12 @@ object SeqPattern { final def isAny(p: SeqPart1[A]): Boolean = part1SetOps.isTop(p) - private def subsetList(p1: List[SeqPart[A]], p2: List[SeqPart[A]]): Boolean = + private def subsetList( + p1: List[SeqPart[A]], + p2: List[SeqPart[A]] + ): Boolean = (p1, p2) match { - case (Nil, Nil) => true + case (Nil, Nil) => true case (Nil, (_: SeqPart1[A]) :: _) => false case (Nil, Wildcard :: t) => subsetList(Nil, t) @@ -251,15 +252,15 @@ object SeqPattern { case ((h1: SeqPart1[A]) :: t1, (h2: SeqPart1[A]) :: t2) => part1SetOps.subset(h1, h2) && subsetList(t1, t2) case (Wildcard :: Wildcard :: t1, _) => - // normalize the left: - subsetList(Wildcard :: t1, p2) + // normalize the left: + subsetList(Wildcard :: t1, p2) case (_, Wildcard :: Wildcard :: t2) => - // normalize the right: - subsetList(p1, Wildcard :: t2) + // normalize the right: + subsetList(p1, Wildcard :: t2) case (_, Wildcard :: (a2: SeqPart1[A]) :: t2) if isAny(a2) => - // we know that right can't match empty, - // let's see if that helps us rule out matches on the left - subsetList(p1, AnyElem :: Wildcard :: t2) + // we know that right can't match empty, + // let's see if that helps us rule out matches on the left + subsetList(p1, AnyElem :: Wildcard :: t2) // either t1 or t2 also ends with Wildcard case (_ :: _, Wildcard :: _) if p2.last.notWild => // wild on the right but not at the end @@ -275,8 +276,8 @@ object SeqPattern { // p1 = *t1 = t1 + _:p1 // _:p1 <= h2:t2 => (_ <= h2) && (p1 <= t2) isAny(h2) && - subsetList(t1, p2) && - subsetList(p1, t2) + subsetList(t1, p2) && + subsetList(p1, t2) case ((_: SeqPart1[A]) :: t1, Wildcard :: t2) => // we could pop off one wildcard to match head // or we could match with nothing but the rest @@ -293,10 +294,12 @@ object SeqPattern { subsetList(t1, p2) } - /** - * Compute a list of patterns that matches both patterns exactly - */ - def intersection(p1: SeqPattern[A], p2: SeqPattern[A]): List[SeqPattern[A]] = + /** Compute a list of patterns that matches both patterns exactly + */ + def intersection( + p1: SeqPattern[A], + p2: SeqPattern[A] + ): List[SeqPattern[A]] = (p1, p2) match { case (Empty, _) => if (p2.matchesEmpty) p1 :: Nil else Nil @@ -305,12 +308,12 @@ object SeqPattern { case (_, Cat(Wildcard, _)) if p2.matchesAny => // matches anything p1 :: Nil - case (_, _) if subset(p1, p2) => p1 :: Nil - case (_, _) if subset(p2, p1) => p2 :: Nil - case (Cat(Wildcard, t1@Cat(Wildcard, _)), _) => + case (_, _) if subset(p1, p2) => p1 :: Nil + case (_, _) if subset(p2, p1) => p2 :: Nil + case (Cat(Wildcard, t1 @ Cat(Wildcard, _)), _) => // unnormalized intersection(t1, p2) - case (_, Cat(Wildcard, t2@Cat(Wildcard, _))) => + case (_, Cat(Wildcard, t2 @ Cat(Wildcard, _))) => // unnormalized intersection(p1, t2) case (Cat(Wildcard, Cat(a1: SeqPart1[A], t1)), _) if isAny(a1) => @@ -319,7 +322,8 @@ object SeqPattern { case (_, Cat(Wildcard, Cat(a2: SeqPart1[A], t2))) if isAny(a2) => // *. == .*, push Wildcards to the end intersection(p1, Cat(AnyElem, Cat(Wildcard, t2))) - case (c1@Cat(Wildcard, _), c2@Cat(Wildcard, _)) if c1.rightMost.notWild || c2.rightMost.notWild => + case (c1 @ Cat(Wildcard, _), c2 @ Cat(Wildcard, _)) + if c1.rightMost.notWild || c2.rightMost.notWild => // let's avoid the most complex case of both having // wild on the front if possible intersection(c1.reverse, c2.reverse).map(_.reverse) @@ -354,50 +358,45 @@ object SeqPattern { unifyUnion(intr) } - /** - * return the patterns that match p1 but not p2 - * - * For fixed sets A, B if we have (A1 x B1) - (A2 x B2) = - * A1 = (A1 n A2) u (A1 - A2) - * A2 = (A1 n A2) u (A2 - A1) - * so we can decompose: - * - * A1 x B1 = (A1 n A2)xB1 u (A1 - A2)xB1 - * A2 x B2 = (A1 n A2)xB2 u (A2 - A1)xB2 - * - * the difference is: - * (A1 n A2)x(B1 - B2) u (A1 - A2)xB1 - * - * A - (B1 u B2) = (A - B1) n (A - B2) - * A - (B1 u B2) <= ((A - B1) u (A - B2)) - (B1 n B2) - * A - (B1 u B2) >= (A - B1) u (A - B2) - * - * so if (B1 n B2) = 0, then: - * A - (B1 u B2) = (A - B1) u (A - B2) - * - * (A1 u A2) - B = (A1 - B) u (A2 - B) - * - * The last challenge is we need to operate on - * s ingle characters, so we need to expand - * wild into [*] = [] | [_, *], since our pattern - * language doesn't have a symbol for - * a single character match we have to be a bit more careful - * - * also, we can't exactly represent Wildcard - Lit - * so this is actually an upperbound on the difference - * which is to say, all the returned patterns match p1, - * but some of them also match p2 - */ - def difference(p1: SeqPattern[A], p2: SeqPattern[A]): List[SeqPattern[A]] = + /** return the patterns that match p1 but not p2 + * + * For fixed sets A, B if we have (A1 x B1) - (A2 x B2) = A1 = (A1 n A2) + * u (A1 - A2) A2 = (A1 n A2) u (A2 - A1) so we can decompose: + * + * A1 x B1 = (A1 n A2)xB1 u (A1 - A2)xB1 A2 x B2 = (A1 n A2)xB2 u (A2 - + * A1)xB2 + * + * the difference is: (A1 n A2)x(B1 - B2) u (A1 - A2)xB1 + * + * A - (B1 u B2) = (A - B1) n (A - B2) A - (B1 u B2) <= ((A - B1) u (A - + * B2)) - (B1 n B2) A - (B1 u B2) >= (A - B1) u (A - B2) + * + * so if (B1 n B2) = 0, then: A - (B1 u B2) = (A - B1) u (A - B2) + * + * (A1 u A2) - B = (A1 - B) u (A2 - B) + * + * The last challenge is we need to operate on s ingle characters, so we + * need to expand wild into [*] = [] | [_, *], since our pattern language + * doesn't have a symbol for a single character match we have to be a bit + * more careful + * + * also, we can't exactly represent Wildcard - Lit so this is actually an + * upperbound on the difference which is to say, all the returned + * patterns match p1, but some of them also match p2 + */ + def difference( + p1: SeqPattern[A], + p2: SeqPattern[A] + ): List[SeqPattern[A]] = (p1, p2) match { case (Empty, _) => if (p2.matchesEmpty) Nil else p1 :: Nil case (Cat(_: SeqPart1[A], _), Empty) => // Cat(SeqPart1[A], _) does not match Empty p1 :: Nil - case (Cat(Wildcard, t1@Cat(Wildcard, _)), _) => + case (Cat(Wildcard, t1 @ Cat(Wildcard, _)), _) => // unnormalized difference(t1, p2) - case (_, Cat(Wildcard, t2@Cat(Wildcard, _))) => + case (_, Cat(Wildcard, t2 @ Cat(Wildcard, _))) => // unnormalized difference(p1, t2) case (Cat(Wildcard, Cat(a1: SeqPart1[A], t1)), _) if isAny(a1) => @@ -425,11 +424,9 @@ object SeqPattern { if (intH.isEmpty) { // then h1 - h2 = h1 p1 :: Nil - } - else if (disjoint(t1, t2)) { + } else if (disjoint(t1, t2)) { p1 :: Nil - } - else { + } else { val d1 = for { h <- intH @@ -452,16 +449,15 @@ object SeqPattern { // *:t1 - (h2:t2) = t1 + _:p1 - h2:t2 // = (t1 - p2) + (_:p1 - h2:t2) val d12 = { - //(_:p1 - h2:t2) = - //(_ n h2):(p1 - t2) + (_ - h2):p1 - //h2:(p1 - t2) + (_ - h2):p1 + // (_:p1 - h2:t2) = + // (_ n h2):(p1 - t2) + (_ - h2):p1 + // h2:(p1 - t2) + (_ - h2):p1 // - //or: - //(_ - h2):(p1 n t2) + _:(p1 - t2) + // or: + // (_ - h2):(p1 n t2) + _:(p1 - t2) if (disjoint(p1, t2)) { Cat(AnyElem, p1) :: Nil - } - else { + } else { val dtail = difference(p1, t2) val d1 = dtail.map(Cat(h2, _)) val d2 = part1SetOps.difference(AnyElem, h2).map(Cat(_, p1)) @@ -470,13 +466,12 @@ object SeqPattern { } val d3 = difference(t1, p2) unifyUnion(d12 ::: d3) - case (c1@Cat(Wildcard, t1), c2@Cat(Wildcard, t2)) => + case (c1 @ Cat(Wildcard, t1), c2 @ Cat(Wildcard, t2)) => if (c1.rightMost.notWild || c2.rightMost.notWild) { // let's avoid the most complex case of both having // wild on the front if possible difference(c1.reverse, c2.reverse).map(_.reverse) - } - else { + } else { // both start and end with wildcard // // p1 - (t2 + _:p2) = @@ -520,8 +515,7 @@ object SeqPattern { val as = difference(p1, t2) if (t1.isEmpty) { as - } - else { + } else { // if x <= *:(a n b) and a then it is <= a n (*:(a n b)) val bs = difference(t1, Cat(AnyElem, p2)) // (a1 + a2) n (b1 + b2) = @@ -542,7 +536,9 @@ object SeqPattern { } } - def matcher[A, I, S, R](split: Splitter[A, I, S, R]): Matcher[SeqPattern[A], S, R] = + def matcher[A, I, S, R]( + split: Splitter[A, I, S, R] + ): Matcher[SeqPattern[A], S, R] = new Matcher[SeqPattern[A], S, R] { import SeqPart.{AnyElem, Lit, SeqPart1, Wildcard} @@ -561,7 +557,8 @@ object SeqPattern { (h, t) = ht rh <- mh(h) rt <- mt(t) - } yield split.monoidResult.combine(rh, rt) } + } yield split.monoidResult.combine(rh, rt) + } case Cat(AnyElem, t) => val mt = apply(t) @@ -570,26 +567,31 @@ object SeqPattern { ht <- split.uncons(s) (_, t) = ht rt <- mt(t) - } yield rt } + } yield rt + } case Cat(Wildcard, t) => matchEnd(t).andThen(_.headOption.map(_._2)) } def matchEnd(p: SeqPattern[A]): S => LazyList[(S, R)] = p match { - case Empty => { (s: S) => (s, split.monoidResult.empty) #:: LazyList.empty } + case Empty => { (s: S) => + (s, split.monoidResult.empty) #:: LazyList.empty + } case Cat(p: SeqPart1[A], t) => val splitFn: S => LazyList[(S, I, R, S)] = p match { - case Lit(c) => split.positions(c) + case Lit(c) => split.positions(c) case AnyElem => split.anySplits(_: S) } val tailMatch = apply(t) { (s: S) => splitFn(s) - .map { case (pre, _, r, post) => + .map { case (pre, _, r, post) => tailMatch(post) - .map { rtail => (pre, split.monoidResult.combine(r, rtail)) } + .map { rtail => + (pre, split.monoidResult.combine(r, rtail)) + } } .collect { case Some(res) => res } } diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/SetOps.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/SetOps.scala index 366b44d32..f5cd0c532 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/SetOps.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/SetOps.scala @@ -1,74 +1,59 @@ package org.bykn.bosatsu.pattern -/** - * These are set operations we can do on patterns - */ +/** These are set operations we can do on patterns + */ trait SetOps[A] { - /** - * a representation of the set with everything in it - * not all sets have upper bounds we can represent - */ + /** a representation of the set with everything in it not all sets have upper + * bounds we can represent + */ def top: Option[A] - /** - * if everything is <= A, maybe more than one representation of top - */ + /** if everything is <= A, maybe more than one representation of top + */ def isTop(a: A): Boolean - /** - * intersect two values and return a union represented as a list - */ + /** intersect two values and return a union represented as a list + */ def intersection(a1: A, a2: A): List[A] - /** - * Return true if a1 and a2 are disjoint - */ + /** Return true if a1 and a2 are disjoint + */ def disjoint(a1: A, a2: A): Boolean = intersection(a1, a2).isEmpty - /** - * remove a2 from a1 return a union represented as a list - * - * this should be the tightest upperbound we can find - */ + /** remove a2 from a1 return a union represented as a list + * + * this should be the tightest upperbound we can find + */ def difference(a1: A, a2: A): List[A] - /** - * This should unify the union into the fewest number - * of patterns without changing the meaning of the union - */ + /** This should unify the union into the fewest number of patterns without + * changing the meaning of the union + */ def unifyUnion(u: List[A]): List[A] - /** - * if true, all elements in a are in b, - * if false, there is no promise - * - * this should be a reasonable cheap operation - * that is allowed to say no in order - * to avoid very expensive work - */ + /** if true, all elements in a are in b, if false, there is no promise + * + * this should be a reasonable cheap operation that is allowed to say no in + * order to avoid very expensive work + */ def subset(a: A, b: A): Boolean - /** - * Remove all items in p2 from all items in p1 - * and unify the remaining union - */ + /** Remove all items in p2 from all items in p1 and unify the remaining union + */ def differenceAll(p1: List[A], p2: List[A]): List[A] = p2.foldLeft(p1) { (p1s, p) => // remove p from all of p1s p1s.flatMap(difference(_, p)) } - /** - * if top is defined - * a list of matches that would make the current set of matches total - * - * Note, a law here is that: - * missingBranches(te, t, branches).flatMap { ms => - * assert(missingBranches(te, t, branches ::: ms).isEmpty) - * } - */ + /** if top is defined a list of matches that would make the current set of + * matches total + * + * Note, a law here is that: missingBranches(te, t, branches).flatMap { ms => + * assert(missingBranches(te, t, branches ::: ms).isEmpty) } + */ def missingBranches(top: List[A], branches: List[A]): List[A] = { // we can subtract in any order // since a - b - c = a - c - b @@ -78,7 +63,9 @@ trait SetOps[A] { // 3! = 6 val lookahead = 3 - val missing = SetOps.greedySearch(lookahead, top, unifyUnion(branches))(differenceAll(_, _))(_.size) + val missing = SetOps.greedySearch(lookahead, top, unifyUnion(branches))( + differenceAll(_, _) + )(_.size) // filter any unreachable, which can happen when earlier items shadow later // ones @@ -86,10 +73,9 @@ trait SetOps[A] { missing.filterNot(unreach.toSet) } - /** - * if we match these branches in order, which of them - * are completely covered by previous matches - */ + /** if we match these branches in order, which of them are completely covered + * by previous matches + */ def unreachableBranches(branches: List[A]): List[A] = { def withPrev(bs: List[A], prev: List[A]): List[(A, List[A])] = bs match { @@ -99,9 +85,10 @@ trait SetOps[A] { } withPrev(branches, Nil) - .collect { case (p, prev) if differenceAll(p :: Nil, prev).isEmpty => - // if there is nothing, this is unreachable - p + .collect { + case (p, prev) if differenceAll(p :: Nil, prev).isEmpty => + // if there is nothing, this is unreachable + p } } } @@ -124,7 +111,9 @@ object SetOps { } // we search for the best order to apply the diffs that minimizes the score - def greedySearch[A, B, C: Ordering](lookahead: Int, union: A, diffs: List[B])(fn: (A, List[B]) => A)(score: A => C): A = + def greedySearch[A, B, C: Ordering](lookahead: Int, union: A, diffs: List[B])( + fn: (A, List[B]) => A + )(score: A => C): A = diffs match { case Nil => union case _ => @@ -149,7 +138,6 @@ object SetOps { greedySearch(lookahead, u1, diffs.filterNot(_ == best))(fn)(score) } - def distinct[A](implicit ordA: Ordering[A]): SetOps[A] = new SetOps[A] { def top: Option[A] = None @@ -167,7 +155,7 @@ object SetOps { def nub(u: List[A]): List[A] = u match { case Nil | _ :: Nil => u - case h1 :: (t1@(h2 :: _)) => + case h1 :: (t1 @ (h2 :: _)) => if (ordA.equiv(h1, h2)) nub(t1) else h1 :: nub(t1) } @@ -240,7 +228,7 @@ object SetOps { def top: Option[(A, B)] = (sa.top, sb.top) match { case (Some(a), Some(b)) => Some((a, b)) - case _ => None + case _ => None } def isTop(a: (A, B)): Boolean = @@ -304,15 +292,15 @@ object SetOps { def unifyUnion(u: List[(A, B)]): List[(A, B)] = { def step[X, Y](u: List[(X, Y)], sy: SetOps[Y]): Option[List[(X, Y)]] = { var change = false - val u1 = u.groupBy(_._1) + val u1 = u + .groupBy(_._1) .iterator .flatMap { case (x, xys) => val uy = sy.unifyUnion(xys.map(_._2)) if (uy.size < xys.size) { change = true uy.map((x, _)) - } - else xys + } else xys } .toList @@ -324,7 +312,7 @@ object SetOps { step(u, sb) match { case None => step(u.map(_.swap), sa) match { - case None => u + case None => u case Some(u2) => // we got a change unifying a loop(u2.map(_.swap)) diff --git a/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala b/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala index 14c957cba..30943af0d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala +++ b/core/src/main/scala/org/bykn/bosatsu/pattern/Splitter.scala @@ -24,22 +24,29 @@ abstract class Splitter[-Elem, Item, Sequence, R] { } object Splitter { - def stringSplitter[R](fn: Char => R)(implicit m: Monoid[R]): Splitter[Char, Char, String, R] = + def stringSplitter[R]( + fn: Char => R + )(implicit m: Monoid[R]): Splitter[Char, Char, String, R] = new Splitter[Char, Char, String, R] { val matcher = Matcher.charMatcher .mapWithInput { (s, _) => fn(s) } val monoidResult = m - def positions(c: Char): String => LazyList[(String, Char, R, String)] = { str => - def loop(init: Int): LazyList[(String, Char, R, String)] = - if (init >= str.length) LazyList.empty - else if (str.charAt(init) == c) { - (str.substring(0, init), c, fn(c), str.substring(init + 1)) #:: loop(init + 1) - } - else loop(init + 1) - - loop(0) + def positions(c: Char): String => LazyList[(String, Char, R, String)] = { + str => + def loop(init: Int): LazyList[(String, Char, R, String)] = + if (init >= str.length) LazyList.empty + else if (str.charAt(init) == c) { + ( + str.substring(0, init), + c, + fn(c), + str.substring(init + 1) + ) #:: loop(init + 1) + } else loop(init + 1) + + loop(0) } def anySplits(str: String): LazyList[(String, Char, R, String)] = @@ -74,12 +81,15 @@ object Splitter { val matchFn = matcher(c) { (str: List[V]) => - def loop(tail: List[V], acc: List[V]): LazyList[(List[V], V, R, List[V])] = + def loop( + tail: List[V], + acc: List[V] + ): LazyList[(List[V], V, R, List[V])] = tail match { case Nil => LazyList.empty case h :: t => matchFn(h) match { - case None => loop(t, h :: acc) + case None => loop(t, h :: acc) case Some(r) => (acc.reverse, h, r, t) #:: loop(t, h :: acc) } } @@ -92,7 +102,8 @@ object Splitter { def loop(str: List[V], acc: List[V]): LazyList[(List[V], V, R, List[V])] = str match { case Nil => LazyList.empty - case h :: t => (acc.reverse, h, monoidResult.empty, t) #:: loop(t, h :: acc) + case h :: t => + (acc.reverse, h, monoidResult.empty, t) #:: loop(t, h :: acc) } loop(str, Nil) } @@ -111,7 +122,9 @@ object Splitter { final override def fromList(cs: List[V]) = cs } - def listSplitter[P, V, R](m: Matcher[P, V, R])(implicit mon: Monoid[R]): Splitter[P, V, List[V], R] = + def listSplitter[P, V, R]( + m: Matcher[P, V, R] + )(implicit mon: Monoid[R]): Splitter[P, V, List[V], R] = new ListSplitter[P, V, R] { val matcher = m val monoidResult = mon diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala index db7f7c593..c17f2ac9d 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/DataRepr.scala @@ -1,9 +1,7 @@ package org.bykn.bosatsu.rankn -/** - * How is a non-external data type - * represented - */ +/** How is a non-external data type represented + */ sealed abstract class DataRepr @@ -13,7 +11,8 @@ object DataRepr { case object ZeroNat extends Nat(true) case object SuccNat extends Nat(false) - case class Enum(variant: Int, arity: Int, familyArities: List[Int]) extends DataRepr + case class Enum(variant: Int, arity: Int, familyArities: List[Int]) + extends DataRepr // a struct with arity 1 can be elided, and is called a new-type case class Struct(arity: Int) extends DataRepr { require(arity != 1) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala index 38d9d462b..52f47520c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/DefinedType.scala @@ -10,30 +10,28 @@ import cats.implicits._ import cats.data.NonEmptyList final case class DefinedType[+A]( - packageName: PackageName, - name: TypeName, - annotatedTypeParams: List[(Type.Var.Bound, A)], - constructors: List[ConstructorFn]) { + packageName: PackageName, + name: TypeName, + annotatedTypeParams: List[(Type.Var.Bound, A)], + constructors: List[ConstructorFn] +) { val typeParams: List[Type.Var.Bound] = annotatedTypeParams.map(_._1) require(typeParams.distinct == typeParams, typeParams.toString) - /** - * This is not the full type, since the full type - * has a ForAll(typeParams, ... in front if the - * typeParams is nonEmpty - */ + /** This is not the full type, since the full type has a ForAll(typeParams, + * ... in front if the typeParams is nonEmpty + */ val toTypeConst: Type.Const.Defined = DefinedType.toTypeConst(packageName, name) val toTypeTyConst: Type.TyConst = Type.TyConst(toTypeConst) - /** - * A type with exactly one constructor is a struct - */ + /** A type with exactly one constructor is a struct + */ def isStruct: Boolean = dataFamily == DataFamily.Struct val dataRepr: Constructor => DataRepr = @@ -47,13 +45,11 @@ final case class DefinedType[+A]( val zero = c0.name { cons => if (cons == zero) DataRepr.ZeroNat else DataRepr.SuccNat } - } - else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) { + } else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) { val zero = c1.name { cons => if (cons == zero) DataRepr.ZeroNat else DataRepr.SuccNat } - } - else { + } else { val famArities = c0.arity :: c1.arity :: Nil val zero = c0.name val zrep = DataRepr.Enum(0, c0.arity, famArities) @@ -63,7 +59,9 @@ final case class DefinedType[+A]( } case cons => val famArities = cons.map(_.arity) - val mapping = cons.zipWithIndex.map { case (c, idx) => c.name -> DataRepr.Enum(idx, c.arity, famArities) }.toMap + val mapping = cons.zipWithIndex.map { case (c, idx) => + c.name -> DataRepr.Enum(idx, c.arity, famArities) + }.toMap mapping } @@ -76,12 +74,15 @@ final case class DefinedType[+A]( case c0 :: c1 :: Nil => // exactly two constructor functions if (c0.isZeroArg && c1.hasSingleArgType(toTypeTyConst)) DataFamily.Nat - else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) DataFamily.Nat + else if (c1.isZeroArg && c0.hasSingleArgType(toTypeTyConst)) + DataFamily.Nat else DataFamily.Enum case _ => DataFamily.Enum - } + } - private def toAnnotatedKinds(implicit ev: A <:< Kind.Arg): List[(Type.Var.Bound, Kind.Arg)] = { + private def toAnnotatedKinds(implicit + ev: A <:< Kind.Arg + ): List[(Type.Var.Bound, Kind.Arg)] = { type L[+X] = List[(Type.Var.Bound, X)] ev.substituteCo[L](annotatedTypeParams) } @@ -91,11 +92,11 @@ final case class DefinedType[+A]( val tc: Type = Type.const(packageName, name) val res = typeParams.foldLeft(tc) { (res, v) => - Type.TyApply(res, Type.TyVar(v)) - } + Type.TyApply(res, Type.TyVar(v)) + } val resT = NonEmptyList.fromList(cf.args.map(_._2)) match { case Some(nel) => Type.Fun(nel, res) - case None => res + case None => res } val typeArgs = toAnnotatedKinds.map { case (b, ka) => (b, ka.kind) } Type.forAll(typeArgs, resT) @@ -109,24 +110,33 @@ object DefinedType { def toTypeConst(pn: PackageName, nm: TypeName): Type.Const.Defined = Type.Const.Defined(pn, nm) - def listToMap[A](dts: List[DefinedType[A]]): SortedMap[(PackageName, TypeName), DefinedType[A]] = + def listToMap[A]( + dts: List[DefinedType[A]] + ): SortedMap[(PackageName, TypeName), DefinedType[A]] = SortedMap(dts.map { dt => (dt.packageName, dt.name) -> dt }: _*) - def toKindMap[F[_]: Foldable](dts: F[DefinedType[Kind.Arg]]): Map[Type.Const.Defined, Kind] = - dts.foldLeft( - Map.newBuilder[Type.Const.Defined, Kind] - ) { (b, dt) => b += ((dt.toTypeConst.toDefined, dt.kindOf)) } - .result() + def toKindMap[F[_]: Foldable]( + dts: F[DefinedType[Kind.Arg]] + ): Map[Type.Const.Defined, Kind] = + dts + .foldLeft( + Map.newBuilder[Type.Const.Defined, Kind] + ) { (b, dt) => b += ((dt.toTypeConst.toDefined, dt.kindOf)) } + .result() implicit val definedTypeTraverse: Traverse[DefinedType] = new Traverse[DefinedType] { val listTup = Traverse[List].compose[(Type.Var.Bound, *)] - def traverse[F[_]: Applicative, A, B](da: DefinedType[A])(fn: A => F[B]): F[DefinedType[B]] = + def traverse[F[_]: Applicative, A, B]( + da: DefinedType[A] + )(fn: A => F[B]): F[DefinedType[B]] = listTup.traverse(da.annotatedTypeParams)(fn).map { ap => da.copy(annotatedTypeParams = ap) } - def foldRight[A, B](fa: DefinedType[A], b: Eval[B])(fn: (A, Eval[B]) => Eval[B]): Eval[B] = + def foldRight[A, B](fa: DefinedType[A], b: Eval[B])( + fn: (A, Eval[B]) => Eval[B] + ): Eval[B] = listTup.foldRight(fa.annotatedTypeParams, b)(fn) def foldLeft[A, B](fa: DefinedType[A], b: B)(fn: (B, A) => B): B = diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index e3c55b2e5..c5d691f77 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -16,7 +16,8 @@ import org.bykn.bosatsu.{ Region, RecursionKind, TypedExpr, - Variance} + Variance +} import scala.collection.mutable.{Map => MMap} @@ -40,15 +41,17 @@ sealed abstract class Infer[+A] { Infer.Impl.MapEither(this, fn) final def runVar( - v: Map[Infer.Name, Type], - tpes: Map[(PackageName, Constructor), Infer.Cons], - kinds: Map[Type.Const.Defined, Kind]): RefSpace[Either[Error, A]] = + v: Map[Infer.Name, Type], + tpes: Map[(PackageName, Constructor), Infer.Cons], + kinds: Map[Type.Const.Defined, Kind] + ): RefSpace[Either[Error, A]] = Infer.Env.init(v, tpes, kinds).flatMap(run(_)) final def runFully( - v: Map[Infer.Name, Type], - tpes: Map[(PackageName, Constructor), Infer.Cons], - kinds: Map[Type.Const.Defined, Kind]): Either[Error, A] = + v: Map[Infer.Name, Type], + tpes: Map[(PackageName, Constructor), Infer.Cons], + kinds: Map[Type.Const.Defined, Kind] + ): Either[Error, A] = runVar(v, tpes, kinds).run.value } @@ -68,68 +71,77 @@ object Infer { TailRecM(a, fn) } - - /** - * The first element of the tuple are the the bound type - * vars for this type. - * the next are the types of the args of the constructor - * the final is the defined type this creates - */ + /** The first element of the tuple are the the bound type vars for this type. + * the next are the types of the args of the constructor the final is the + * defined type this creates + */ type Cons = (List[(Type.Var.Bound, Kind.Arg)], List[Type], Type.Const.Defined) type Name = (Option[PackageName], Identifier) class Env( - val uniq: Ref[Long], - val vars: Map[Name, Type], - val typeCons: Map[(PackageName, Constructor), Cons], - val variances: Map[Type.Const.Defined, Kind]) { + val uniq: Ref[Long], + val vars: Map[Name, Type], + val typeCons: Map[(PackageName, Constructor), Cons], + val variances: Map[Type.Const.Defined, Kind] + ) { def addVars(vt: List[(Name, Type)]): Env = new Env(uniq, vars = vars ++ vt, typeCons, variances) - private[this] val kindCache: MMap[(Type, Map[Type.Var.Bound, Kind]), Either[Region => Error, Kind]] = + private[this] val kindCache: MMap[(Type, Map[Type.Var.Bound, Kind]), Either[ + Region => Error, + Kind + ]] = MMap() def getKind(t: Type, region: Region): Either[Error, Kind] = { - def loop(item: Type, locals: Map[Type.Var.Bound, Kind]): Either[Region => Error, Kind] = + def loop( + item: Type, + locals: Map[Type.Var.Bound, Kind] + ): Either[Region => Error, Kind] = item match { case Type.TyVar(b @ Type.Var.Bound(_)) => // don't cache locals, there is no point locals.get(b) match { case Some(k) => Right(k) - case None => + case None => // $COVERAGE-OFF$ this should be unreachable because all vars should have a known kind - Left({ region => Error.UnknownKindOfVar(t, region, s"unbound var: $b") }) - // $COVERAGE-ON$ this should be unreachable + Left({ region => + Error.UnknownKindOfVar(t, region, s"unbound var: $b") + }) + // $COVERAGE-ON$ this should be unreachable } case Type.TyVar(Type.Var.Skolem(_, kind, _)) => Right(kind) - case Type.TyMeta(Type.Meta(kind, _, _)) => Right(kind) + case Type.TyMeta(Type.Meta(kind, _, _)) => Right(kind) case Type.TyConst(const) => val d = const.toDefined // some tests rely on syntax without importing // TODO remove this variances.get(d).orElse(Type.builtInKinds.get(d)) match { case Some(ks) => Right(ks) - case None => Left({region => Error.UnknownDefined(d, region) }) + case None => Left({ region => Error.UnknownDefined(d, region) }) } case _ => - kindCache.getOrElseUpdate((item, locals), + kindCache.getOrElseUpdate( + (item, locals), item match { case Type.ForAll(bound, t) => loop(t, locals ++ bound.toList) - case ap@Type.TyApply(left, right) => + case ap @ Type.TyApply(left, right) => loop(left, locals) .product(loop(right, locals)) - .flatMap { - case (leftKind, rhs) => - Kind.validApply[Region => Error](leftKind, rhs, - { region => - Error.KindCannotTyApply(ap, region) - }) { cons => - { region => - Error.KindInvalidApply(ap, cons, rhs, region) - } + .flatMap { case (leftKind, rhs) => + Kind.validApply[Region => Error]( + leftKind, + rhs, + { region => + Error.KindCannotTyApply(ap, region) } + ) { cons => + { region => + Error.KindInvalidApply(ap, cons, rhs, region) + } + } } // $COVERAGE-OFF$ this should be unreachable because we handle Var above case _ => @@ -145,9 +157,10 @@ object Infer { object Env { def init( - vars: Map[Name, Type], - tpes: Map[(PackageName, Constructor), Cons], - kinds: Map[Type.Const.Defined, Kind]): RefSpace[Env] = + vars: Map[Name, Type], + tpes: Map[(PackageName, Constructor), Cons], + kinds: Map[Type.Const.Defined, Kind] + ): RefSpace[Env] = RefSpace.newRef(0L).map(new Env(_, vars, tpes, kinds)) } @@ -171,7 +184,7 @@ object Infer { def lookupVarType(v: Name, reg: Region): Infer[Type] = getEnv.flatMap { env => env.get(v) match { - case None => fail(Error.VarNotInScope(v, env, reg)) + case None => fail(Error.VarNotInScope(v, env, reg)) case Some(t) => pure(t) } } @@ -180,72 +193,134 @@ object Infer { object Error { - /** - * These are errors in the ability to type the code - * Generally these cannot be caught by other phases - */ + /** These are errors in the ability to type the code Generally these cannot + * be caught by other phases + */ sealed abstract class TypeError extends Error - case class NotUnifiable(left: Type, right: Type, leftRegion: Region, rightRegion: Region) extends TypeError - case class KindNotUnifiable(leftK: Kind, leftT: Type, rightK: Kind, rightT: Type, leftRegion: Region, rightRegion: Region) extends TypeError - case class KindInvalidApply(typeApply: Type.TyApply, leftK: Kind.Cons, rightK: Kind, region: Region) extends TypeError - case class KindMetaMismatch(meta: Type.TyMeta, inferred: Type.Tau, inferredKind: Kind, metaRegion: Region, inferredRegion: Region) extends TypeError - case class KindCannotTyApply(ap: Type.TyApply, region: Region) extends TypeError - case class UnknownDefined(tpe: Type.Const.Defined, region: Region) extends TypeError - case class NotPolymorphicEnough(tpe: Type, in: Expr[_], badTvs: NonEmptyList[Type.Var.Skolem], reg: Region) extends TypeError - case class SubsumptionCheckFailure(inferred: Type, declared: Type, infRegion: Region, decRegion: Region, badTvs: NonEmptyList[Type.Var]) extends TypeError + case class NotUnifiable( + left: Type, + right: Type, + leftRegion: Region, + rightRegion: Region + ) extends TypeError + case class KindNotUnifiable( + leftK: Kind, + leftT: Type, + rightK: Kind, + rightT: Type, + leftRegion: Region, + rightRegion: Region + ) extends TypeError + case class KindInvalidApply( + typeApply: Type.TyApply, + leftK: Kind.Cons, + rightK: Kind, + region: Region + ) extends TypeError + case class KindMetaMismatch( + meta: Type.TyMeta, + inferred: Type.Tau, + inferredKind: Kind, + metaRegion: Region, + inferredRegion: Region + ) extends TypeError + case class KindCannotTyApply(ap: Type.TyApply, region: Region) + extends TypeError + case class UnknownDefined(tpe: Type.Const.Defined, region: Region) + extends TypeError + case class NotPolymorphicEnough( + tpe: Type, + in: Expr[_], + badTvs: NonEmptyList[Type.Var.Skolem], + reg: Region + ) extends TypeError + case class SubsumptionCheckFailure( + inferred: Type, + declared: Type, + infRegion: Region, + decRegion: Region, + badTvs: NonEmptyList[Type.Var] + ) extends TypeError // this sounds internal but can be due to an infinite type attempted to be defined - case class UnexpectedMeta(m: Type.Meta, in: Type, left: Region, right: Region) extends TypeError - case class ArityMismatch(leftArity: Int, leftRegion: Region, rightArity: Int, rightRegion: Region) extends TypeError - case class ArityTooLarge(arity: Int, maxArity: Int, region: Region) extends TypeError - - /** - * These are errors that prevent typing due to unknown names, - * They could be caught in a phase that collects all the naming errors - */ + case class UnexpectedMeta( + m: Type.Meta, + in: Type, + left: Region, + right: Region + ) extends TypeError + case class ArityMismatch( + leftArity: Int, + leftRegion: Region, + rightArity: Int, + rightRegion: Region + ) extends TypeError + case class ArityTooLarge(arity: Int, maxArity: Int, region: Region) + extends TypeError + + /** These are errors that prevent typing due to unknown names, They could be + * caught in a phase that collects all the naming errors + */ sealed abstract class NameError extends Error // This could be a user error if we don't check scoping before typing - case class VarNotInScope(varName: Name, vars: Map[Name, Type], region: Region) extends NameError + case class VarNotInScope( + varName: Name, + vars: Map[Name, Type], + region: Region + ) extends NameError // This could be a user error if we don't check scoping before typing - case class UnexpectedBound(v: Type.Var.Bound, in: Type, rb: Region, rt: Region) extends NameError - case class UnknownConstructor(name: (PackageName, Constructor), region: Region, env: Env) extends NameError { - def knownConstructors: List[(PackageName, Constructor)] = env.typeCons.keys.toList.sorted + case class UnexpectedBound( + v: Type.Var.Bound, + in: Type, + rb: Region, + rt: Region + ) extends NameError + case class UnknownConstructor( + name: (PackageName, Constructor), + region: Region, + env: Env + ) extends NameError { + def knownConstructors: List[(PackageName, Constructor)] = + env.typeCons.keys.toList.sorted } - case class UnionPatternBindMismatch(pattern: Pattern, names: NonEmptyList[List[Identifier.Bindable]], region: Region) extends NameError - - /** - * These can only happen if the compiler has bugs at some point - */ + case class UnionPatternBindMismatch( + pattern: Pattern, + names: NonEmptyList[List[Identifier.Bindable]], + region: Region + ) extends NameError + + /** These can only happen if the compiler has bugs at some point + */ sealed abstract class InternalError extends Error { def message: String def region: Region } // This is a logic error which should never happen - case class InferIncomplete(method: String, term: Expr[_], region: Region) extends InternalError { + case class InferIncomplete(method: String, term: Expr[_], region: Region) + extends InternalError { // $COVERAGE-OFF$ we don't test these messages, maybe they should be removed def message = s"$method not complete for $term" // $COVERAGE-ON$ we don't test these messages, maybe they should be removed } - case class ExpectedRho(tpe: Type, context: String, region: Region) extends InternalError { + case class ExpectedRho(tpe: Type, context: String, region: Region) + extends InternalError { // $COVERAGE-OFF$ we don't test these messages, maybe they should be removed def message = s"expected $tpe to be a Type.Rho, at $context" // $COVERAGE-ON$ we don't test these messages, maybe they should be removed } - case class UnknownKindOfVar(tpe: Type, region: Region, mess: String) extends InternalError { + case class UnknownKindOfVar(tpe: Type, region: Region, mess: String) + extends InternalError { // $COVERAGE-OFF$ we don't test these messages, maybe they should be removed def message = s"unknown var in $tpe: $mess at $region" // $COVERAGE-ON$ we don't test these messages, maybe they should be removed } } - - /** - * This is where the internal implementation goes. - * It is here to make it easy to make one block private - * and not do so on every little helper function - */ + /** This is where the internal implementation goes. It is here to make it easy + * to make one block private and not do so on every little helper function + */ private object Impl { sealed abstract class Expected[A] object Expected { @@ -260,7 +335,7 @@ object Infer { def run(env: Env) = fa.run(env).flatMap { case Left(msg) => RefSpace.pure(Left(msg)) - case Right(a) => fn(a).run(env) + case Right(a) => fn(a).run(env) } } case class Peek[A](fa: Infer[A]) extends Infer[Either[Error, A]] { @@ -272,22 +347,24 @@ object Infer { // $COVERAGE-ON$ this should be unreachable } } - case class MapEither[A, B](fa: Infer[A], fn: A => Either[Error, B]) extends Infer[B] { + case class MapEither[A, B](fa: Infer[A], fn: A => Either[Error, B]) + extends Infer[B] { def run(env: Env) = fa.run(env).flatMap { case Left(msg) => RefSpace.pure(Left(msg)) - case Right(a) => RefSpace.pure(fn(a)) + case Right(a) => RefSpace.pure(fn(a)) } } // $COVERAGE-OFF$ needed for Monad, but not actually used - case class TailRecM[A, B](init: A, fn: A => Infer[Either[A, B]]) extends Infer[B] { + case class TailRecM[A, B](init: A, fn: A => Infer[Either[A, B]]) + extends Infer[B] { def run(env: Env) = { // RefSpace uses Eval so this is fine, if not maybe the fastest thing ever def loop(a: A): RefSpace[Either[Error, B]] = fn(a).run(env).flatMap { - case Left(err) => RefSpace.pure(Left(err)) - case Right(Left(a)) => loop(a) + case Left(err) => RefSpace.pure(Left(err)) + case Right(Left(a)) => loop(a) case Right(Right(b)) => RefSpace.pure(Right(b)) } loop(init) @@ -309,7 +386,8 @@ object Infer { } } - case class ExtendEnvs[A](vt: List[(Name, Type)], in: Infer[A]) extends Infer[A] { + case class ExtendEnvs[A](vt: List[(Name, Type)], in: Infer[A]) + extends Infer[A] { def run(env: Env) = in.run(env.addVars(vt)) } @@ -335,8 +413,7 @@ object Infer { private val checkedKinds: Infer[Type => Option[Kind]] = { val emptyRegion = Region(0, 0) GetEnv.map { env => - - { tpe => env.getKind(tpe, emptyRegion).toOption } + { tpe => env.getKind(tpe, emptyRegion).toOption } } } @@ -346,26 +423,31 @@ object Infer { kindOf(ta.on, region) .flatMap(varianceOfConsKind(ta, _, region)) - def varianceOfConsKind(ta: Type.TyApply, k: Kind, region: Region): Infer[Variance] = + def varianceOfConsKind( + ta: Type.TyApply, + k: Kind, + region: Region + ): Infer[Variance] = k match { case Kind.Cons(Kind.Arg(v, _), _) => pure(v) case Kind.Type => fail(Error.KindCannotTyApply(ta, region)) } - - /** - * Skolemize on a function just recurses on the result type. - * - * Skolemize replaces ForAll parameters with skolem variables - * and then skolemizes recurses on the substituted value - * - * otherwise we return the type. - * - * The returned type is in weak-prenex form: all ForAlls have - * been floated up over covariant parameters - */ - private def skolemize(t: Type, region: Region): Infer[(List[Type.Var.Skolem], Type.Rho)] = + /** Skolemize on a function just recurses on the result type. + * + * Skolemize replaces ForAll parameters with skolem variables and then + * skolemizes recurses on the substituted value + * + * otherwise we return the type. + * + * The returned type is in weak-prenex form: all ForAlls have been floated + * up over covariant parameters + */ + private def skolemize( + t: Type, + region: Region + ): Infer[(List[Type.Var.Skolem], Type.Rho)] = t match { case Type.ForAll(tvs, ty) => // Rule PRPOLY @@ -375,7 +457,7 @@ object Infer { sks2ty <- skolemize(substTyRho(tvs.map(_._1), sksT)(ty), region) (sks2, ty2) = sks2ty } yield (sks1.toList ::: sks2, ty2) - case ta@Type.TyApply(left, right) => + case ta @ Type.TyApply(left, right) => // Rule PRFUN // we know the kind of left is k -> x, and right has kind k varianceOfCons(ta, region) @@ -410,10 +492,9 @@ object Infer { } } - /** - * This fills in any meta vars that have been - * quantified and replaces them with what they point to - */ + /** This fills in any meta vars that have been quantified and replaces them + * with what they point to + */ def zonkType(t: Type): Infer[Type] = Type.zonkMeta(t)(zonk(_)) @@ -423,13 +504,20 @@ object Infer { def initRef[A](err: Error): Infer[Ref[Either[Error, A]]] = lift(RefSpace.newRef[Either[Error, A]](Left(err))) - def substTyRho(keys: NonEmptyList[Type.Var], vals: NonEmptyList[Type.Rho]): Type.Rho => Type.Rho = { + def substTyRho( + keys: NonEmptyList[Type.Var], + vals: NonEmptyList[Type.Rho] + ): Type.Rho => Type.Rho = { val env = keys.toList.iterator.zip(vals.toList.iterator).toMap { t => Type.substituteRhoVar(t, env) } } - def substTyExpr[A](keys: NonEmptyList[Type.Var], vals: NonEmptyList[Type.Rho], expr: TypedExpr[A]): TypedExpr[A] = { + def substTyExpr[A]( + keys: NonEmptyList[Type.Var], + vals: NonEmptyList[Type.Rho], + expr: TypedExpr[A] + ): TypedExpr[A] = { val fn = Type.substTy(keys, vals) expr.traverseType[cats.Id](fn) } @@ -443,7 +531,11 @@ object Infer { * new meta variables for each bound variable in ForAll or skolemize * which replaces the ForAll variables with skolem variables */ - def assertRho(t: Type, context: => String, region: Region): Infer[Type.Rho] = + def assertRho( + t: Type, + context: => String, + region: Region + ): Infer[Type.Rho] = t match { case r: Type.Rho => pure(r) // $COVERAGE-OFF$ this should be unreachable @@ -461,7 +553,8 @@ object Infer { // TODO: it may be possible to improve type checking // by pushing foralls into covariant constructors // but it's not trivial - vars.traverse { case (_, k) => newMetaType(k) } + vars + .traverse { case (_, k) => newMetaType(k) } .map { vars1T => substTyRho(vars.map(_._1), vars1T)(rho) } @@ -471,11 +564,20 @@ object Infer { /* * Invariant: r2 needs to be in weak prenex form */ - def subsCheckFn(a1s: NonEmptyList[Type], r1: Type, a2s: NonEmptyList[Type], r2: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheckFn( + a1s: NonEmptyList[Type], + r1: Type, + a2s: NonEmptyList[Type], + r2: Type.Rho, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = // note due to contravariance in input, we reverse the order there for { // we know that they have the same length because we have already called unifyFn - coarg <- a2s.zip(a1s).traverse { case (a2, a1) => subsCheck(a2, a1, left, right) } + coarg <- a2s.zip(a1s).traverse { case (a2, a1) => + subsCheck(a2, a1, left, right) + } // r2 is already in weak-prenex form cores <- subsCheckRho(r1, r2, left, right) ks <- checkedKinds @@ -488,95 +590,119 @@ object Infer { * was rewritten to: * forall a, b. a -> (b -> b) */ - def subsCheckRho(t: Type, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheckRho( + t: Type, + rho: Type.Rho, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = (t, rho) match { - case (fa@Type.ForAll(_, _), rho) => + case (fa @ Type.ForAll(_, _), rho) => // Rule SPEC instantiate(fa, left).flatMap(subsCheckRho(_, rho, left, right)) case (rhot: Type.Rho, rho) => subsCheckRho2(rhot, rho, left, right) } - def subsCheckRho2(t: Type.Rho, rho: Type.Rho, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheckRho2( + t: Type.Rho, + rho: Type.Rho, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = // get the kinds to make sure they are well kindinded kindOf(t, left).product(kindOf(rho, right)) *> - ((t, rho) match { - case (rho1, Type.Fun(a2, r2)) => - // Rule FUN - for { - a1r1 <- unifyFn(a2.length, rho1, left, right) - (a1, r1) = a1r1 - // since rho is in weak prenex form, and Fun is covariant on r2, we know - // r2 is in weak-prenex form and a rho type - rhor2 <- assertRho(r2, s"subsCheckRho($t, $rho, $left, $right), line 462", right) - coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) - } yield coerce - case (Type.Fun(a1, r1), rho2) => - // Rule FUN - for { - a2r2 <- unifyFn(a1.length, rho2, right, left) - (a2, r2) = a2r2 - // since rho is in weak prenex form, and Fun is covariant on r2, we know - // r2 is in weak-prenex form - rhor2 <- assertRho(r2, s"subsCheckRho($t, $rho, $left, $right), line 471", right) - coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) - } yield coerce - case (rho1, ta@Type.TyApply(l2, r2)) => - for { - kl <- kindOf(l2, right) - kr <- kindOf(r2, right) - l1r1 <- unifyTyApp(rho1, kl, kr, left, right) - (l1, r1) = l1r1 - _ <- varianceOfConsKind(ta, kl, right).flatMap { - case Variance.Covariant => - subsCheck(r1, r2, left, right).void - case Variance.Contravariant => - subsCheck(r2, r1, right, left).void - case Variance.Phantom => - // this doesn't matter - unit - case Variance.Invariant => - unifyType(r1, r2, left, right) - } - // should we coerce to t2? Seems like... but copying previous code - _ <- subsCheck(l1, l2, left, right) - ks <- checkedKinds - } yield TypedExpr.coerceRho(rho1, ks) - case (ta@Type.TyApply(l1, r1), rho2) => - for { - kl <- kindOf(l1, left) - kr <- kindOf(r1, left) - l2r2 <- unifyTyApp(rho2, kl, kr, left, right) - (l2, r2) = l2r2 - _ <- varianceOfConsKind(ta, kl, left).flatMap { - case Variance.Covariant => - subsCheck(r1, r2, left, right).void - case Variance.Contravariant => - subsCheck(r2, r1, right, left).void - case Variance.Phantom => - // this doesn't matter - unit - case Variance.Invariant => - unifyType(r1, r2, left, right) - } - _ <- subsCheck(l1, l2, left, right) - ks <- checkedKinds - // should we coerce to t2? Seems like... but copying previous code - } yield TypedExpr.coerceRho(ta, ks) - case (t1, t2) => - // rule: MONO - unify(t1, t2, left, right) *> checkedKinds.map(TypedExpr.coerceRho(t1, _)) // TODO this coerce seems right, since we have unified - }) + ((t, rho) match { + case (rho1, Type.Fun(a2, r2)) => + // Rule FUN + for { + a1r1 <- unifyFn(a2.length, rho1, left, right) + (a1, r1) = a1r1 + // since rho is in weak prenex form, and Fun is covariant on r2, we know + // r2 is in weak-prenex form and a rho type + rhor2 <- assertRho( + r2, + s"subsCheckRho($t, $rho, $left, $right), line 462", + right + ) + coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) + } yield coerce + case (Type.Fun(a1, r1), rho2) => + // Rule FUN + for { + a2r2 <- unifyFn(a1.length, rho2, right, left) + (a2, r2) = a2r2 + // since rho is in weak prenex form, and Fun is covariant on r2, we know + // r2 is in weak-prenex form + rhor2 <- assertRho( + r2, + s"subsCheckRho($t, $rho, $left, $right), line 471", + right + ) + coerce <- subsCheckFn(a1, r1, a2, rhor2, left, right) + } yield coerce + case (rho1, ta @ Type.TyApply(l2, r2)) => + for { + kl <- kindOf(l2, right) + kr <- kindOf(r2, right) + l1r1 <- unifyTyApp(rho1, kl, kr, left, right) + (l1, r1) = l1r1 + _ <- varianceOfConsKind(ta, kl, right).flatMap { + case Variance.Covariant => + subsCheck(r1, r2, left, right).void + case Variance.Contravariant => + subsCheck(r2, r1, right, left).void + case Variance.Phantom => + // this doesn't matter + unit + case Variance.Invariant => + unifyType(r1, r2, left, right) + } + // should we coerce to t2? Seems like... but copying previous code + _ <- subsCheck(l1, l2, left, right) + ks <- checkedKinds + } yield TypedExpr.coerceRho(rho1, ks) + case (ta @ Type.TyApply(l1, r1), rho2) => + for { + kl <- kindOf(l1, left) + kr <- kindOf(r1, left) + l2r2 <- unifyTyApp(rho2, kl, kr, left, right) + (l2, r2) = l2r2 + _ <- varianceOfConsKind(ta, kl, left).flatMap { + case Variance.Covariant => + subsCheck(r1, r2, left, right).void + case Variance.Contravariant => + subsCheck(r2, r1, right, left).void + case Variance.Phantom => + // this doesn't matter + unit + case Variance.Invariant => + unifyType(r1, r2, left, right) + } + _ <- subsCheck(l1, l2, left, right) + ks <- checkedKinds + // should we coerce to t2? Seems like... but copying previous code + } yield TypedExpr.coerceRho(ta, ks) + case (t1, t2) => + // rule: MONO + unify(t1, t2, left, right) *> checkedKinds.map( + TypedExpr.coerceRho(t1, _) + ) // TODO this coerce seems right, since we have unified + }) /* * Invariant: if the second argument is (Check rho) then rho is in weak prenex form */ - def instSigma(sigma: Type, expect: Expected[(Type.Rho, Region)], r: Region): Infer[TypedExpr.Coerce] = + def instSigma( + sigma: Type, + expect: Expected[(Type.Rho, Region)], + r: Region + ): Infer[TypedExpr.Coerce] = expect match { case Expected.Check((t, tr)) => // note t is in weak-prenex form subsCheckRho(sigma, t, r, tr) - case infer@Expected.Inf(_) => + case infer @ Expected.Inf(_) => for { rho <- instantiate(sigma, r) _ <- infer.set((rho, r)) @@ -584,19 +710,26 @@ object Infer { } yield TypedExpr.coerceRho(rho, ks) } - def unifyFn(arity: Int, fnType: Type.Rho, fnRegion: Region, evidenceRegion: Region): Infer[(NonEmptyList[Type], Type)] = + def unifyFn( + arity: Int, + fnType: Type.Rho, + fnRegion: Region, + evidenceRegion: Region + ): Infer[(NonEmptyList[Type], Type)] = fnType match { case Type.Fun(arg, res) => val fnArity = arg.length if (fnArity == arity) pure((arg, res)) - else fail(Error.ArityMismatch(fnArity, fnRegion, arity, evidenceRegion)) + else + fail(Error.ArityMismatch(fnArity, fnRegion, arity, evidenceRegion)) case tau => val args = if (Type.FnType.ValidArity.unapply(arity)) { pure(NonEmptyList.fromListUnsafe((1 to arity).toList)) - } - else { - fail(Error.ArityTooLarge(arity, Type.FnType.MaxSize, evidenceRegion)) + } else { + fail( + Error.ArityTooLarge(arity, Type.FnType.MaxSize, evidenceRegion) + ) } for { sized <- args @@ -606,23 +739,50 @@ object Infer { } yield (argT, resT) } - def unifyKind(kind1: Kind, tpe1: Type, kind2: Kind, tpe2: Type, region1: Region, region2: Region): Infer[Unit] = + def unifyKind( + kind1: Kind, + tpe1: Type, + kind2: Kind, + tpe2: Type, + region1: Region, + region2: Region + ): Infer[Unit] = // we may need to be tracking kinds and widen them... - if (Kind.leftSubsumesRight(kind1, kind2) || Kind.leftSubsumesRight(kind2, kind2)) unit - else fail(Error.KindNotUnifiable(kind1, tpe1, kind2, tpe2, region1, region2)) - - private def checkApply[A](apType: Type.TyApply, lKind: Kind, rKind: Kind, apRegion: Region)(next: Infer[A]): Infer[A] = - Kind.validApply[Error](lKind, rKind, - Error.KindCannotTyApply(apType, apRegion)) { cons => - Error.KindInvalidApply(apType, cons, rKind, apRegion) - } match { - case Right(_) => next - case Left(err) => fail(err) - } + if ( + Kind.leftSubsumesRight(kind1, kind2) || Kind.leftSubsumesRight( + kind2, + kind2 + ) + ) unit + else + fail(Error.KindNotUnifiable(kind1, tpe1, kind2, tpe2, region1, region2)) + + private def checkApply[A]( + apType: Type.TyApply, + lKind: Kind, + rKind: Kind, + apRegion: Region + )(next: Infer[A]): Infer[A] = + Kind.validApply[Error]( + lKind, + rKind, + Error.KindCannotTyApply(apType, apRegion) + ) { cons => + Error.KindInvalidApply(apType, cons, rKind, apRegion) + } match { + case Right(_) => next + case Left(err) => fail(err) + } - def unifyTyApp(apType: Type.Rho, lKind: Kind, rKind: Kind, apRegion: Region, evidenceRegion: Region): Infer[(Type, Type)] = + def unifyTyApp( + apType: Type.Rho, + lKind: Kind, + rKind: Kind, + apRegion: Region, + evidenceRegion: Region + ): Infer[(Type, Type)] = apType match { - case ap@Type.TyApply(left, right) => + case ap @ Type.TyApply(left, right) => checkApply(ap, lKind, rKind, apRegion)(pure((left, right))) case notApply => for { @@ -636,18 +796,30 @@ object Infer { } // invariant the flexible type variable tv1 is not bound - def unifyUnboundVar(m: Type.Meta, ty2: Type.Tau, left: Region, right: Region): Infer[Unit] = + def unifyUnboundVar( + m: Type.Meta, + ty2: Type.Tau, + left: Region, + right: Region + ): Infer[Unit] = ty2 match { - case meta2@Type.TyMeta(m2) => + case meta2 @ Type.TyMeta(m2) => readMeta(m2).flatMap { case Some(ty2) => unify(Type.TyMeta(m), ty2, left, right) case None => if (Kind.leftSubsumesRight(m.kind, m2.kind)) { // we have to check that the kind matches before writing to a meta writeMeta(m, ty2) - } - else { - fail(Error.KindMetaMismatch(Type.TyMeta(m), meta2, m2.kind, left, right)) + } else { + fail( + Error.KindMetaMismatch( + Type.TyMeta(m), + meta2, + m2.kind, + left, + right + ) + ) } } case nonMeta => @@ -661,18 +833,30 @@ object Infer { if (Kind.leftSubsumesRight(m.kind, nmk)) { // we have to check that the kind matches before writing to a meta writeMeta(m, nonMeta) - } - else { - fail(Error.KindMetaMismatch(Type.TyMeta(m), nonMeta, nmk, left, right)) + } else { + fail( + Error.KindMetaMismatch( + Type.TyMeta(m), + nonMeta, + nmk, + left, + right + ) + ) } } - } + } } } - def unifyVar(tv: Type.Meta, t: Type.Tau, left: Region, right: Region): Infer[Unit] = + def unifyVar( + tv: Type.Meta, + t: Type.Tau, + left: Region, + right: Region + ): Infer[Unit] = readMeta(tv).flatMap { - case None => unifyUnboundVar(tv, t, left, right) + case None => unifyUnboundVar(tv, t, left, right) case Some(ty1) => unify(ty1, t, left, right) } @@ -684,19 +868,17 @@ object Infer { case (Type.TyApply(a1, b1), Type.TyApply(a2, b2)) => unifyType(a1, a2, r1, r2) *> unifyType(b1, b2, r1, r2) case (Type.TyConst(c1), Type.TyConst(c2)) if c1 == c2 => unit - case (Type.TyVar(v1), Type.TyVar(v2)) if v1 == v2 => unit - case (Type.TyVar(b@Type.Var.Bound(_)), _) => + case (Type.TyVar(v1), Type.TyVar(v2)) if v1 == v2 => unit + case (Type.TyVar(b @ Type.Var.Bound(_)), _) => fail(Error.UnexpectedBound(b, t2, r1, r2)) - case (_, Type.TyVar(b@Type.Var.Bound(_))) => + case (_, Type.TyVar(b @ Type.Var.Bound(_))) => fail(Error.UnexpectedBound(b, t1, r2, r1)) case (left, right) => fail(Error.NotUnifiable(left, right, r1, r2)) } - /** - * for a type to be unified, we mean we can substitute in either - * direction - */ + /** for a type to be unified, we mean we can substitute in either direction + */ def unifyType(t1: Type, t2: Type, r1: Region, r2: Region): Infer[Unit] = (t1, t2) match { case (rho1: Type.Rho, rho2: Type.Rho) => @@ -705,10 +887,9 @@ object Infer { subsCheck(t1, t2, r1, r2) *> subsCheck(t2, t1, r2, r1).void } - /** - * Allocate a new Meta variable which - * will point to a Tau (no forall anywhere) type - */ + /** Allocate a new Meta variable which will point to a Tau (no forall + * anywhere) type + */ def newMetaType(kind: Kind): Infer[Type.TyMeta] = for { id <- nextId @@ -720,22 +901,24 @@ object Infer { def newSkolemTyVar(tv: Type.Var.Bound, kind: Kind): Infer[Type.Var.Skolem] = nextId.map(Type.Var.Skolem(tv.name, kind, _)) - /** - * See if the meta variable has been set with a Tau - * type - */ + /** See if the meta variable has been set with a Tau type + */ def readMeta(m: Type.Meta): Infer[Option[Type.Tau]] = lift(m.ref.get) - /** - * Set the meta variable to point to a Tau type - */ + /** Set the meta variable to point to a Tau type + */ private def writeMeta(m: Type.Meta, v: Type.Tau): Infer[Unit] = lift(m.ref.set(Some(v))) // DEEP-SKOL rule // note, this is identical to subsCheckRho when declared is a Rho type - def subsCheck(inferred: Type, declared: Type, left: Region, right: Region): Infer[TypedExpr.Coerce] = + def subsCheck( + inferred: Type, + declared: Type, + left: Region, + right: Region + ): Infer[TypedExpr.Coerce] = for { skolRho <- skolemize(declared, right) (skolTvs, rho2) = skolRho @@ -745,54 +928,70 @@ object Infer { res <- NonEmptyList.fromList(skolTvs) match { case None => pure(coerce) case Some(nel) => - getFreeTyVars(inferred :: declared :: Nil).flatMap { escTvs => - NonEmptyList.fromList(skolTvs.filter(escTvs)) match { - case None => pure(coerce.andThen(unskolemize(nel))) - case Some(badTvs) => fail(Error.SubsumptionCheckFailure(inferred, declared, left, right, badTvs)) - } - } - } + getFreeTyVars(inferred :: declared :: Nil).flatMap { escTvs => + NonEmptyList.fromList(skolTvs.filter(escTvs)) match { + case None => pure(coerce.andThen(unskolemize(nel))) + case Some(badTvs) => + fail( + Error.SubsumptionCheckFailure( + inferred, + declared, + left, + right, + badTvs + ) + ) + } + } + } } yield res - /** - * Invariant: if the second argument is (Check rho) then rho is in weak prenex form - */ - def typeCheckRho[A: HasRegion](term: Expr[A], expect: Expected[(Type.Rho, Region)]): Infer[TypedExpr.Rho[A]] = { + /** Invariant: if the second argument is (Check rho) then rho is in weak + * prenex form + */ + def typeCheckRho[A: HasRegion]( + term: Expr[A], + expect: Expected[(Type.Rho, Region)] + ): Infer[TypedExpr.Rho[A]] = { import Expr._ term match { case Literal(lit, t) => val tpe = Type.getTypeOf(lit) - instSigma(tpe, expect, region(term)).map(_(TypedExpr.Literal(lit, tpe, t))) + instSigma(tpe, expect, region(term)).map( + _(TypedExpr.Literal(lit, tpe, t)) + ) case Local(name, tag) => for { vSigma <- lookupVarType((None, name), region(term)) coerce <- instSigma(vSigma, expect, region(term)) - } yield coerce(TypedExpr.Local(name, vSigma, tag)) + } yield coerce(TypedExpr.Local(name, vSigma, tag)) case Global(pack, name, tag) => for { vSigma <- lookupVarType((Some(pack), name), region(term)) coerce <- instSigma(vSigma, expect, region(term)) - } yield coerce(TypedExpr.Global(pack, name, vSigma, tag)) + } yield coerce(TypedExpr.Global(pack, name, vSigma, tag)) case App(fn, args, tag) => - for { - typedFnTpe <- inferRho(fn) - (typedFn, fnTRho) = typedFnTpe - argsRegion = args.reduceMap(region[Expr[A]](_)) - argRes <- unifyFn(args.length, fnTRho, region(fn), argsRegion) - (argT, resT) = argRes - typedArg <- args.zip(argT).traverse { case (arg, argT) => checkSigma(arg, argT) } - coerce <- instSigma(resT, expect, region(term)) - } yield coerce(TypedExpr.App(typedFn, typedArg, resT, tag)) + for { + typedFnTpe <- inferRho(fn) + (typedFn, fnTRho) = typedFnTpe + argsRegion = args.reduceMap(region[Expr[A]](_)) + argRes <- unifyFn(args.length, fnTRho, region(fn), argsRegion) + (argT, resT) = argRes + typedArg <- args.zip(argT).traverse { case (arg, argT) => + checkSigma(arg, argT) + } + coerce <- instSigma(resT, expect, region(term)) + } yield coerce(TypedExpr.App(typedFn, typedArg, resT, tag)) case Generic(tpes, in) => - for { - (skols, t1) <- Expr.skolemizeVars(tpes, in)(newSkolemTyVar(_, _)) - sigmaT <- inferSigma(t1) - z <- zonkTypedExpr(sigmaT) - unSkol = unskolemize(skols)(z) - // unSkol is not a Rho type, we need instantiate it - coerce <- instSigma(unSkol.getType, expect, region(term)) - } yield coerce(unSkol) + for { + (skols, t1) <- Expr.skolemizeVars(tpes, in)(newSkolemTyVar(_, _)) + sigmaT <- inferSigma(t1) + z <- zonkTypedExpr(sigmaT) + unSkol = unskolemize(skols)(z) + // unSkol is not a Rho type, we need instantiate it + coerce <- instSigma(unSkol.getType, expect, region(term)) + } yield coerce(unSkol) case Lambda(args, result, tag) => expect match { case Expected.Check((expTy, rr)) => @@ -801,26 +1000,30 @@ object Infer { // we know expTy is in weak-prenex form, and since Fn is covariant, bodyT must be // in weak prenex form (varsT, bodyT) = vb - bodyTRho <- assertRho(bodyT, s"expect a rho type in $vb from $expTy at $rr", region(result)) + bodyTRho <- assertRho( + bodyT, + s"expect a rho type in $vb from $expTy at $rr", + region(result) + ) // the length of args and varsT must be the same because of unifyFn zipped = args.zip(varsT) namesVarsT = zipped.map { case ((n, _), t) => (n, t) } typedBody <- extendEnvList(namesVarsT.toList) { - // TODO we are ignoring the result of subsCheck here - // should we be coercing a var? - // - // this comes from page 54 of the paper, but I can't seem to find examples - // where this will fail if we reverse (as we had for a long time), which - // indicates the testing coverage is incomplete - zipped.traverse_ { - case ((_, Some(tpe)), varT) => - subsCheck(varT, tpe, region(term), rr) - case ((_, None), _) => unit - } *> + // TODO we are ignoring the result of subsCheck here + // should we be coercing a var? + // + // this comes from page 54 of the paper, but I can't seem to find examples + // where this will fail if we reverse (as we had for a long time), which + // indicates the testing coverage is incomplete + zipped.traverse_ { + case ((_, Some(tpe)), varT) => + subsCheck(varT, tpe, region(term), rr) + case ((_, None), _) => unit + } *> checkRho(result, bodyTRho) - } + } } yield TypedExpr.AnnotatedLambda(namesVarsT, typedBody, tag) - case infer@Expected.Inf(_) => + case infer @ Expected.Inf(_) => for { nameVarsT <- args.traverse { case (n, Some(tpe)) => @@ -830,11 +1033,15 @@ object Infer { // all functions args of kind type newMetaType(Kind.Type).map((n, _)) } - typedBodyTpe <- extendEnvList(nameVarsT.toList)(inferRho(result)) + typedBodyTpe <- extendEnvList(nameVarsT.toList)( + inferRho(result) + ) (typedBody, bodyT) = typedBodyTpe - _ <- infer.set((Type.Fun(nameVarsT.map(_._2), bodyT), region(term))) + _ <- infer.set( + (Type.Fun(nameVarsT.map(_._2), bodyT), region(term)) + ) } yield TypedExpr.AnnotatedLambda(nameVarsT, typedBody, tag) - } + } case Let(name, rhs, body, isRecursive, tag) => if (isRecursive.isRecursive) { // all defs are marked at potentially recursive. @@ -847,10 +1054,10 @@ object Infer { // compilers/evaluation can possibly optimize non-recursive // cases differently val rhsBody = rhs match { - case Annotation(expr, tpe, tag) => - extendEnv(name, tpe) { - checkSigma(expr, tpe).product(typeCheckRho(body, expect)) - } + case Annotation(expr, tpe, tag) => + extendEnv(name, tpe) { + checkSigma(expr, tpe).product(typeCheckRho(body, expect)) + } case _ => newMetaType(Kind.Type) // the kind of a let value is a Type .flatMap { rhsTpe => @@ -858,15 +1065,20 @@ object Infer { for { // the type variable needs to be unified with varT // note, varT could be a sigma type, it is not a Tau or Rho - typedRhs <- inferSigmaMeta(rhs, Some((name, rhsTpe, region(rhs)))) + typedRhs <- inferSigmaMeta( + rhs, + Some((name, rhsTpe, region(rhs))) + ) varT = typedRhs.getType // we need to overwrite the metavariable now with the full type - typedBody <- extendEnv(name, varT)(typeCheckRho(body, expect)) + typedBody <- extendEnv(name, varT)( + typeCheckRho(body, expect) + ) } yield (typedRhs, typedBody) } } - } - + } + rhsBody.map { case (rhs, body) => // TODO: a more efficient algorithm would do this top down // for each top level TypedExpr and build it bottom up. @@ -875,8 +1087,7 @@ object Infer { val isRecursive = RecursionKind.recursive(frees.contains(name)) TypedExpr.Let(name, rhs, body, isRecursive, tag) } - } - else { + } else { // In this branch, we typecheck the rhs *without* name in the environment // so any recursion in this case won't typecheck, and shadowing rules are // in place @@ -917,7 +1128,7 @@ object Infer { checkBranch(p, check, r, resT) } } yield TypedExpr.Match(tsigma, tbranches, tag) - case infer@Expected.Inf(_) => + case infer @ Expected.Inf(_) => for { tbranches <- branches.traverse { case (p, r) => inferBranch(p, check, r) @@ -930,9 +1141,13 @@ object Infer { } } - def narrowBranches[A: HasRegion](branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))]): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = { + def narrowBranches[A: HasRegion]( + branches: NonEmptyList[(Pattern, (TypedExpr.Rho[A], Type.Rho))] + ): Infer[(Type.Rho, Region, NonEmptyList[(Pattern, TypedExpr.Rho[A])])] = { - def minBy[M[_]: Monad, B](head: B, tail: List[B])(lteq: (B, B) => M[Boolean]): M[B] = + def minBy[M[_]: Monad, B](head: B, tail: List[B])( + lteq: (B, B) => M[Boolean] + ): M[B] = tail match { case Nil => Monad[M].pure(head) case h :: tail => @@ -941,9 +1156,12 @@ object Infer { val next = if (keep) head else h minBy(next, tail)(lteq) } - } + } - def ltEq[K](left: (TypedExpr[A], K), right: (TypedExpr[A], K)): Infer[Boolean] = { + def ltEq[K]( + left: (TypedExpr[A], K), + right: (TypedExpr[A], K) + ): Infer[Boolean] = { val leftTE = left._1 val rightTE = right._1 val lt = leftTE.getType @@ -951,14 +1169,12 @@ object Infer { val rt = rightTE.getType val rr = region(rightTE) // right <= left if left subsumes right - subsCheck(lt, rt, lr, rr) - .peek + subsCheck(lt, rt, lr, rr).peek .flatMap { case Right(_) => pure(true) - case Left(_) => + case Left(_) => // maybe the other way around - subsCheck(rt, lt, rr, lr) - .peek + subsCheck(rt, lt, rr, lr).peek .flatMap { case Right(_) => // okay, we see right > left @@ -970,20 +1186,24 @@ object Infer { } } - val withIdx = branches.zipWithIndex.map { case ((p, (te, tpe)), idx) => (te, (p, tpe, idx)) } + val withIdx = branches.zipWithIndex.map { case ((p, (te, tpe)), idx) => + (te, (p, tpe, idx)) + } for { - (minRes, (minPat, resTRho, minIdx)) <- minBy(withIdx.head, withIdx.tail)((a, b) => ltEq(a, b)) + (minRes, (minPat, resTRho, minIdx)) <- minBy( + withIdx.head, + withIdx.tail + )((a, b) => ltEq(a, b)) resRegion = region(minRes) resBranches <- withIdx.traverse { case (te, (p, tpe, idx)) => if (idx != minIdx) { // unfortunately we have to check each branch again to get the correct coerce subsCheckRho2(resTRho, tpe, resRegion, region(te)) .map { coerce => - (p, coerce(te)) + (p, coerce(te)) } - } - else pure((p, te)) + } else pure((p, te)) } } yield (resTRho, resRegion, resBranches) } @@ -991,14 +1211,23 @@ object Infer { /* * we require resT in weak prenex form because we call checkRho with it */ - def checkBranch[A: HasRegion](p: Pattern, sigma: Expected.Check[(Type, Region)], res: Expr[A], resT: Type.Rho): Infer[(Pattern, TypedExpr.Rho[A])] = + def checkBranch[A: HasRegion]( + p: Pattern, + sigma: Expected.Check[(Type, Region)], + res: Expr[A], + resT: Type.Rho + ): Infer[(Pattern, TypedExpr.Rho[A])] = for { patBind <- typeCheckPattern(p, sigma, region(res)) (pattern, bindings) = patBind tres <- extendEnvList(bindings)(checkRho(res, resT)) } yield (pattern, tres) - def inferBranch[A: HasRegion](p: Pattern, sigma: Expected.Check[(Type, Region)], res: Expr[A]): Infer[(Pattern, (TypedExpr.Rho[A], Type.Rho))] = + def inferBranch[A: HasRegion]( + p: Pattern, + sigma: Expected.Check[(Type, Region)], + res: Expr[A] + ): Infer[(Pattern, (TypedExpr.Rho[A], Type.Rho))] = for { patBind <- typeCheckPattern(p, sigma, region(res)) (pattern, bindings) = patBind @@ -1006,13 +1235,16 @@ object Infer { res <- extendEnvList(bindings)(inferRho(res)) } yield (pattern, res) - /** - * patterns can be a sigma type, not neccesarily a rho/tau - * return a list of bound names and their (sigma) types - * - * TODO: Pattern needs to have a region for each part - */ - def typeCheckPattern(pat: Pattern, sigma: Expected.Check[(Type, Region)], reg: Region): Infer[(Pattern, List[(Bindable, Type)])] = + /** patterns can be a sigma type, not neccesarily a rho/tau return a list of + * bound names and their (sigma) types + * + * TODO: Pattern needs to have a region for each part + */ + def typeCheckPattern( + pat: Pattern, + sigma: Expected.Check[(Type, Region)], + reg: Region + ): Infer[(Pattern, List[(Bindable, Type)])] = pat match { case GenPattern.WildCard => Infer.pure((pat, Nil)) case GenPattern.Literal(lit) => @@ -1032,7 +1264,8 @@ object Infer { def inner(pat: Pattern) = sigma match { case Expected.Check((t, _)) => - val res = (GenPattern.Annotation(GenPattern.Named(n, pat), t), t) + val res = + (GenPattern.Annotation(GenPattern.Named(n, pat), t), t) Infer.pure(res) } // We always return an annotation here, which is the only @@ -1048,8 +1281,8 @@ object Infer { val check = sigma match { case Expected.Check((t, tr)) => subsCheck(tpe, t, reg, tr) } - val names = items.collect { - case GenPattern.StrPart.NamedStr(n) => (n, tpe) + val names = items.collect { case GenPattern.StrPart.NamedStr(n) => + (n, tpe) } // we need to apply the type so the names are well typed val anpat = GenPattern.Annotation(pat, tpe) @@ -1062,25 +1295,34 @@ object Infer { * of them have type A. */ def checkItem( - inner: Type, - lst: Type, - e: ListPart[Pattern]): Infer[(ListPart[Pattern], List[(Bindable, Type)])] = - e match { - case l@ListPart.WildList => - // this is *a pattern that has list type, and binds that type to the name - Infer.pure((l, Nil)) - case l@ListPart.NamedList(splice) => - // this is *a pattern that has list type, and binds that type to the name - Infer.pure((l, (splice, lst) :: Nil)) - case ListPart.Item(p) => - // This is a non-splice - checkPat(p, inner, reg).map { case (p, l) => (ListPart.Item(p), l) } - } + inner: Type, + lst: Type, + e: ListPart[Pattern] + ): Infer[(ListPart[Pattern], List[(Bindable, Type)])] = + e match { + case l @ ListPart.WildList => + // this is *a pattern that has list type, and binds that type to the name + Infer.pure((l, Nil)) + case l @ ListPart.NamedList(splice) => + // this is *a pattern that has list type, and binds that type to the name + Infer.pure((l, (splice, lst) :: Nil)) + case ListPart.Item(p) => + // This is a non-splice + checkPat(p, inner, reg).map { case (p, l) => + (ListPart.Item(p), l) + } + } val tpeOfList: Infer[Type] = sigma.value match { case (Type.TyApply(Type.ListType, item), _) => pure(item) - case (Type.ForAll(b@NonEmptyList(_, Nil), Type.TyApply(Type.ListType, item)), _) => + case ( + Type.ForAll( + b @ NonEmptyList(_, Nil), + Type.TyApply(Type.ListType, item) + ), + _ + ) => // list is covariant so we can push down pure(Type.forAll(b, item)) case (_, reg) => @@ -1097,7 +1339,10 @@ object Infer { inners <- items.traverse(checkItem(tpeA, listA, _)) innerPat = inners.map(_._1) innerBinds = inners.flatMap(_._2) - } yield (GenPattern.Annotation(GenPattern.ListPat(innerPat), listA), innerBinds) + } yield ( + GenPattern.Annotation(GenPattern.ListPat(innerPat), listA), + innerBinds + ) case GenPattern.Annotation(p, tpe) => // like in the case of an annotation, we check the type, then @@ -1115,22 +1360,29 @@ object Infer { // if the pattern arity does not match the arity of the constructor // but we don't want to error type-checking since we want to show // the maximimum number of errors to the user - envs <- args.zip(params).traverse { case (p, t) => checkPat(p, t, reg) } + envs <- args.zip(params).traverse { case (p, t) => + checkPat(p, t, reg) + } pats = envs.map(_._1) bindings = envs.map(_._2) } yield (GenPattern.PositionalStruct(nm, pats), bindings.flatten) - case u@GenPattern.Union(h, t) => - (typeCheckPattern(h, sigma, reg), t.traverse(typeCheckPattern(_, sigma, reg))) - .mapN { case ((h, binds), neList) => - val pat = GenPattern.Union(h, neList.map(_._1)) - val allBinds = NonEmptyList(binds, (neList.map(_._2).toList)) - identicalBinds(u, allBinds, reg).as((pat, binds)) - } - .flatten + case u @ GenPattern.Union(h, t) => + ( + typeCheckPattern(h, sigma, reg), + t.traverse(typeCheckPattern(_, sigma, reg)) + ).mapN { case ((h, binds), neList) => + val pat = GenPattern.Union(h, neList.map(_._1)) + val allBinds = NonEmptyList(binds, (neList.map(_._2).toList)) + identicalBinds(u, allBinds, reg).as((pat, binds)) + }.flatten } // Unions have to have identical bindings in all branches - def identicalBinds(u: Pattern, binds: NonEmptyList[List[(Bindable, Type)]], reg: Region): Infer[Unit] = + def identicalBinds( + u: Pattern, + binds: NonEmptyList[List[(Bindable, Type)]], + reg: Region + ): Infer[Unit] = binds.map(_.map(_._1)) match { case nel @ NonEmptyList(h, t) => val bs = h.toSet @@ -1146,29 +1398,50 @@ object Infer { unifyType(tpe, tpe2, reg, reg) } } - } - else fail(Error.UnionPatternBindMismatch(u, nel, reg)) + } else fail(Error.UnionPatternBindMismatch(u, nel, reg)) } // TODO: we should be able to derive a region for any pattern - def checkPat(pat: Pattern, sigma: Type, reg: Region): Infer[(Pattern, List[(Bindable, Type)])] = + def checkPat( + pat: Pattern, + sigma: Type, + reg: Region + ): Infer[(Pattern, List[(Bindable, Type)])] = typeCheckPattern(pat, Expected.Check((sigma, reg)), reg) - def checkPatSigma(tpe: Type, exp: Expected.Check[(Type, Region)], sRegion: Region): Infer[Unit] = + def checkPatSigma( + tpe: Type, + exp: Expected.Check[(Type, Region)], + sRegion: Region + ): Infer[Unit] = exp match { - case Expected.Check((texp, tr)) => subsCheck(texp, tpe, tr, sRegion).void // this unit does not seem right + case Expected.Check((texp, tr)) => + subsCheck( + texp, + tpe, + tr, + sRegion + ).void // this unit does not seem right } - /** - * To do this, Infer will need to know the names of the type - * constructors in scope. - * - * Instantiation fills in all - */ - def instDataCon(consName: (PackageName, Constructor), sigma: Type, reg: Region, sigmaRegion: Region): Infer[List[Type]] = + /** To do this, Infer will need to know the names of the type constructors + * in scope. + * + * Instantiation fills in all + */ + def instDataCon( + consName: (PackageName, Constructor), + sigma: Type, + reg: Region, + sigmaRegion: Region + ): Infer[List[Type]] = GetDataCons(consName, reg).flatMap { case (args, consParams, tpeName) => val thisTpe = Type.TyConst(tpeName) - def loop(revArgs: List[(Type.Var.Bound, Kind.Arg)], leftKind: Kind, sigma: Type): Infer[Map[Type.Var, Type]] = + def loop( + revArgs: List[(Type.Var.Bound, Kind.Arg)], + leftKind: Kind, + sigma: Type + ): Infer[Map[Type.Var, Type]] = (revArgs, sigma) match { case (Nil, tpe) => for { @@ -1179,10 +1452,17 @@ object Infer { case ((v0, k) :: vs, Type.TyApply(left, right)) => for { rk <- kindOf(right, sigmaRegion) - _ <- unifyKind(k.kind, Type.TyVar(v0), rk, right, reg, sigmaRegion) + _ <- unifyKind( + k.kind, + Type.TyVar(v0), + rk, + right, + reg, + sigmaRegion + ) rest <- loop(vs, Kind.Cons(k, leftKind), left) } yield rest.updated(v0, right) - case (_, fa@Type.ForAll(_, _)) => + case (_, fa @ Type.ForAll(_, _)) => // we have to instantiate a rho type instantiate(fa, sigmaRegion).flatMap(loop(revArgs, leftKind, _)) case ((v0, k) :: rest, _) => @@ -1190,9 +1470,21 @@ object Infer { for { left <- newMetaType(Kind.Cons(k, leftKind)) right <- newMetaType(k.kind) - _ <- unifyType(Type.TyApply(left, right), sigma, reg, sigmaRegion) + _ <- unifyType( + Type.TyApply(left, right), + sigma, + reg, + sigmaRegion + ) sigmaKind <- kindOf(sigma, sigmaRegion) - _ <- unifyKind(leftKind, Type.TyVar(v0), sigmaKind, sigma, reg, sigmaRegion) + _ <- unifyKind( + leftKind, + Type.TyVar(v0), + sigmaKind, + sigma, + reg, + sigmaRegion + ) nextKind = Kind.Cons(k, leftKind) rest <- loop(rest, nextKind, left) } yield rest.updated(v0, right) @@ -1201,45 +1493,60 @@ object Infer { // so we push the forall down to avoid allocating a metaVar which can only // hold a monotype def pushDownCovariant( - revArgs: List[(Type.Var.Bound, Kind.Arg)], - revForAlls: List[(Type.Var.Bound, Kind)], - sigma: Type): Type = { - (revArgs, sigma) match { + revArgs: List[(Type.Var.Bound, Kind.Arg)], + revForAlls: List[(Type.Var.Bound, Kind)], + sigma: Type + ): Type = { + (revArgs, sigma) match { case (_, Type.ForAll(params, over)) => - pushDownCovariant(revArgs, params.toList reverse_::: revForAlls, over) - case ((_, Kind.Arg(Variance.Covariant, _)) :: rest, Type.TyApply(left, right)) => - // TODO Phantom variance has some special rules too. I guess we - // can push into phantom as well (though that's rare) - val leftFree = Type.freeBoundTyVars(left :: Nil).toSet - val rightFree = Type.freeBoundTyVars(right :: Nil).toSet - - val (nextRFA, nextRight) = - revForAlls.filter { case (leftA, _) => rightFree(leftA) && !leftFree(leftA) } match { - case Nil => (revForAlls, right) - case pushed => - // it is safe to push it down - val pushedSet = pushed.iterator.map(_._1).toSet - val revFA1 = revForAlls.toList.filterNot { case (b, _) => pushedSet(b) } - val pushedRight = Type.forAll(pushed.reverse, right) - (revFA1, pushedRight) - } - pushDownCovariant(rest, nextRFA, left) match { - case Type.ForAll(bs, l) => - Type.forAll(bs, Type.TyApply(l, nextRight)) - case rho: Type.Rho => - Type.TyApply(rho, nextRight) + pushDownCovariant( + revArgs, + params.toList reverse_::: revForAlls, + over + ) + case ( + (_, Kind.Arg(Variance.Covariant, _)) :: rest, + Type.TyApply(left, right) + ) => + // TODO Phantom variance has some special rules too. I guess we + // can push into phantom as well (though that's rare) + val leftFree = Type.freeBoundTyVars(left :: Nil).toSet + val rightFree = Type.freeBoundTyVars(right :: Nil).toSet + + val (nextRFA, nextRight) = + revForAlls.filter { case (leftA, _) => + rightFree(leftA) && !leftFree(leftA) + } match { + case Nil => (revForAlls, right) + case pushed => + // it is safe to push it down + val pushedSet = pushed.iterator.map(_._1).toSet + val revFA1 = revForAlls.toList.filterNot { case (b, _) => + pushedSet(b) + } + val pushedRight = Type.forAll(pushed.reverse, right) + (revFA1, pushedRight) } + pushDownCovariant(rest, nextRFA, left) match { + case Type.ForAll(bs, l) => + Type.forAll(bs, Type.TyApply(l, nextRight)) + case rho: Type.Rho => + Type.TyApply(rho, nextRight) + } case (_ :: rest, Type.TyApply(left, right)) => - val rightFree = Type.freeBoundTyVars(right :: Nil).toSet - val (keptRight, lefts) = - revForAlls.partition { case (leftA, _) => rightFree(leftA) } - - Type.forAll(keptRight.reverse, pushDownCovariant(rest, lefts, left)) match { - case Type.ForAll(bs, l) => - Type.forAll(bs, Type.TyApply(l, right)) - case rho: Type.Rho => - Type.TyApply(rho, right) - } + val rightFree = Type.freeBoundTyVars(right :: Nil).toSet + val (keptRight, lefts) = + revForAlls.partition { case (leftA, _) => rightFree(leftA) } + + Type.forAll( + keptRight.reverse, + pushDownCovariant(rest, lefts, left) + ) match { + case Type.ForAll(bs, l) => + Type.forAll(bs, Type.TyApply(l, right)) + case rho: Type.Rho => + Type.TyApply(rho, right) + } case _ => Type.forAll(revForAlls.reverse, sigma) } @@ -1255,7 +1562,10 @@ object Infer { def inferSigma[A: HasRegion](e: Expr[A]): Infer[TypedExpr[A]] = inferSigmaMeta(e, None) - def inferSigmaMeta[A: HasRegion](e: Expr[A], meta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = { + def inferSigmaMeta[A: HasRegion]( + e: Expr[A], + meta: Option[(Identifier, Type.TyMeta, Region)] + ): Infer[TypedExpr[A]] = { def unifySelf(tpe: Type.Rho): Infer[Map[Name, Type]] = meta match { case None => getEnv @@ -1266,37 +1576,41 @@ object Infer { } } - /** - * if meta is Some, it is because it recursive, but those are almost - * always functions, so we can at least fix the arity of the function. - */ - val init: Infer[Unit] = - meta match { - case Some((_, tpe, rtpe)) => - def maybeUnified(e: Expr[A]): Infer[Unit] = - e match { - case Expr.Annotation(e1, t, _) => - unifyType(tpe, t, rtpe, region(e)) *> maybeUnified(e1) - case Expr.Lambda(args, res, _) => - unifyFn(args.length, tpe, rtpe, region(e) - region(res)).void - case _ => - // we just have to wait to infer - unit - } + /** if meta is Some, it is because it recursive, but those are almost + * always functions, so we can at least fix the arity of the function. + */ + val init: Infer[Unit] = + meta match { + case Some((_, tpe, rtpe)) => + def maybeUnified(e: Expr[A]): Infer[Unit] = + e match { + case Expr.Annotation(e1, t, _) => + unifyType(tpe, t, rtpe, region(e)) *> maybeUnified(e1) + case Expr.Lambda(args, res, _) => + unifyFn(args.length, tpe, rtpe, region(e) - region(res)).void + case _ => + // we just have to wait to infer + unit + } - maybeUnified(e) - case None => unit - } + maybeUnified(e) + case None => unit + } for { _ <- init rhoT <- inferRho(e) (rho, expTyRho) = rhoT envTys <- unifySelf(expTyRho) - q <- TypedExpr.quantify(envTys, rho, zonk(_), { (m, n) => - // quantify guarantees that the kind of n matches m - writeMeta(m, Type.TyVar(n)) - }) + q <- TypedExpr.quantify( + envTys, + rho, + zonk(_), + { (m, n) => + // quantify guarantees that the kind of n matches m + writeMeta(m, Type.TyVar(n)) + } + ) } yield q } @@ -1315,63 +1629,86 @@ object Infer { envTys <- getEnv escTvs <- getFreeTyVars(tpe :: envTys.values.toList) badTvs = skols.filter(escTvs) - _ <- require(badTvs.isEmpty, Error.NotPolymorphicEnough(tpe, t, NonEmptyList.fromListUnsafe(badTvs), region(t))) + _ <- require( + badTvs.isEmpty, + Error.NotPolymorphicEnough( + tpe, + t, + NonEmptyList.fromListUnsafe(badTvs), + region(t) + ) + ) // we need to zonk before we unskolemize because some of the metas could be skolems zte <- zonkTypedExpr(te) } yield unskolemize(neskols)(zte) } } yield te1 // should be fine since the everything after te is just checking - /** - * invariant: rho needs to be in weak-prenex form - */ - def checkRho[A: HasRegion](t: Expr[A], rho: Type.Rho): Infer[TypedExpr.Rho[A]] = + /** invariant: rho needs to be in weak-prenex form + */ + def checkRho[A: HasRegion]( + t: Expr[A], + rho: Type.Rho + ): Infer[TypedExpr.Rho[A]] = typeCheckRho(t, Expected.Check((rho, region(t)))) - /** - * recall a rho type never has a top level Forall - */ - def inferRho[A: HasRegion](t: Expr[A]): Infer[(TypedExpr.Rho[A], Type.Rho)] = + /** recall a rho type never has a top level Forall + */ + def inferRho[A: HasRegion]( + t: Expr[A] + ): Infer[(TypedExpr.Rho[A], Type.Rho)] = for { - ref <- initRef[(Type.Rho, Region)](Error.InferIncomplete("inferRho", t, region(t))) + ref <- initRef[(Type.Rho, Region)]( + Error.InferIncomplete("inferRho", t, region(t)) + ) expr <- typeCheckRho(t, Expected.Inf(ref)) // we don't need this ref, and it does not escape, so reset eitherTpe <- lift(ref.get <* ref.reset) tpe <- eitherTpe match { case Right(rho) => pure(rho._1) - case Left(err) => fail(err) + case Left(err) => fail(err) } } yield (expr, tpe) } - private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] = + private def recursiveTypeCheck[A: HasRegion]( + name: Bindable, + expr: Expr[A] + ): Infer[TypedExpr[A]] = // values are of kind Type expr match { case Expr.Annotation(e, tpe, _) => extendEnv(name, tpe)(checkSigma(e, tpe)) case _ => newMetaType(Kind.Type).flatMap { tpe => - extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr))))) + extendEnv(name, tpe)( + typeCheckMeta(expr, Some((name, tpe, region(expr)))) + ) } } - def typeCheck[A: HasRegion](t: Expr[A]): Infer[TypedExpr[A]] = typeCheckMeta(t, None) - private def unskolemize(skols: NonEmptyList[Type.Var.Skolem]): TypedExpr.Coerce = + private def unskolemize( + skols: NonEmptyList[Type.Var.Skolem] + ): TypedExpr.Coerce = new FunctionK[TypedExpr, TypedExpr] { def apply[A](te: TypedExpr[A]) = { // now replace the skols with generics val used = Type.tyVarBinders(te.getType :: Nil) val aligned = Type.alignBinders(skols, used) - val te2 = substTyExpr(skols, aligned.map { case (_, b) => Type.TyVar(b) }, te) + val te2 = + substTyExpr(skols, aligned.map { case (_, b) => Type.TyVar(b) }, te) // TODO: we have to not forget the skolem kinds TypedExpr.forAll(aligned.map { case (s, b) => (b, s.kind) }, te2) } } - private def typeCheckMeta[A: HasRegion](t: Expr[A], optMeta: Option[(Identifier, Type.TyMeta, Region)]): Infer[TypedExpr[A]] = { + private def typeCheckMeta[A: HasRegion]( + t: Expr[A], + optMeta: Option[(Identifier, Type.TyMeta, Region)] + ): Infer[TypedExpr[A]] = { def run(t: Expr[A]) = inferSigmaMeta(t, optMeta).flatMap(zonkTypedExpr _) val optSkols = t match { @@ -1394,32 +1731,40 @@ object Infer { def extendEnv[A](varName: Bindable, tpe: Type)(of: Infer[A]): Infer[A] = extendEnvList((varName, tpe) :: Nil)(of) - def extendEnvList[A](bindings: List[(Bindable, Type)])(of: Infer[A]): Infer[A] = + def extendEnvList[A](bindings: List[(Bindable, Type)])( + of: Infer[A] + ): Infer[A] = Infer.Impl.ExtendEnvs(bindings.map { case (n, t) => ((None, n), t) }, of) - private def extendEnvPack[A](pack: PackageName, name: Bindable, tpe: Type)(of: Infer[A]): Infer[A] = + private def extendEnvPack[A](pack: PackageName, name: Bindable, tpe: Type)( + of: Infer[A] + ): Infer[A] = Infer.Impl.ExtendEnvs(((Some(pack), name), tpe) :: Nil, of) - /** - * Packages are generally just lists of lets, this allows you to infer - * the scheme for each in the context of the list - */ - def typeCheckLets[A: HasRegion](pack: PackageName, ls: List[(Bindable, RecursionKind, Expr[A])]): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = + /** Packages are generally just lists of lets, this allows you to infer the + * scheme for each in the context of the list + */ + def typeCheckLets[A: HasRegion]( + pack: PackageName, + ls: List[(Bindable, RecursionKind, Expr[A])] + ): Infer[List[(Bindable, RecursionKind, TypedExpr[A])]] = ls match { case Nil => Infer.pure(Nil) case (name, rec, expr) :: tail => for { - te <- if (rec.isRecursive) recursiveTypeCheck(name, expr) else typeCheck(expr) - rest <- extendEnvPack(pack, name, te.getType)(typeCheckLets(pack, tail)) + te <- + if (rec.isRecursive) recursiveTypeCheck(name, expr) + else typeCheck(expr) + rest <- extendEnvPack(pack, name, te.getType)( + typeCheckLets(pack, tail) + ) } yield (name, rec, te) :: rest } - /** - * This is useful to testing purposes. - * - * Given types a and b, can we substitute - * a for for b - */ + /** This is useful to testing purposes. + * + * Given types a and b, can we substitute a for for b + */ def substitutionCheck(a: Type, b: Type, ra: Region, rb: Region): Infer[Unit] = subsCheck(a, b, ra, rb).void } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala index fc3c98747..954f9d202 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/ParsedTypeEnv.scala @@ -4,11 +4,18 @@ import org.bykn.bosatsu.{PackageName, Identifier} import Identifier.Bindable -case class ParsedTypeEnv[+A](allDefinedTypes: List[DefinedType[A]], externalDefs: List[(PackageName, Bindable, Type)]) { +case class ParsedTypeEnv[+A]( + allDefinedTypes: List[DefinedType[A]], + externalDefs: List[(PackageName, Bindable, Type)] +) { def addDefinedType[A1 >: A](dt: DefinedType[A1]): ParsedTypeEnv[A1] = copy(allDefinedTypes = dt :: allDefinedTypes) - def addExternalValue(pn: PackageName, name: Bindable, tpe: Type): ParsedTypeEnv[A] = + def addExternalValue( + pn: PackageName, + name: Bindable, + tpe: Type + ): ParsedTypeEnv[A] = copy(externalDefs = (pn, name, tpe) :: externalDefs) } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala index 81912c563..2028f9a3a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Ref.scala @@ -5,9 +5,8 @@ import cats.{StackSafeMonad, Eval} import scala.collection.mutable.{LongMap => MutableMap} import java.util.concurrent.atomic.AtomicLong -/** - * This gives a mutable reference in a monadic context - */ +/** This gives a mutable reference in a monadic context + */ sealed trait Ref[A] { def get: RefSpace[A] def set(a: A): RefSpace[Unit] @@ -41,7 +40,9 @@ object RefSpace { protected def runState(al: AtomicLong, state: State) = value } - private case class AllocRef[A](handle: Long, init: A) extends RefSpace[A] with Ref[A] { + private case class AllocRef[A](handle: Long, init: A) + extends RefSpace[A] + with Ref[A] { def get = this def set(a: A) = SetRef(handle, a) val reset = Reset(handle) @@ -57,26 +58,31 @@ object RefSpace { } } private case class SetRef(handle: Long, value: Any) extends RefSpace[Unit] { - protected def runState(al: AtomicLong, state: State): Eval[Unit] = - { state.put(handle, value); Eval.Unit } + protected def runState(al: AtomicLong, state: State): Eval[Unit] = { + state.put(handle, value); Eval.Unit + } } private case class Reset(handle: Long) extends RefSpace[Unit] { - protected def runState(al: AtomicLong, state: State): Eval[Unit] = - { state.remove(handle); Eval.Unit } + protected def runState(al: AtomicLong, state: State): Eval[Unit] = { + state.remove(handle); Eval.Unit + } } private case class Alloc[A](init: A) extends RefSpace[Ref[A]] { protected def runState(al: AtomicLong, state: State) = Eval.now(AllocRef(al.getAndIncrement, init)) } - private case class Map[A, B](init: RefSpace[A], fn: A => B) extends RefSpace[B] { + private case class Map[A, B](init: RefSpace[A], fn: A => B) + extends RefSpace[B] { protected def runState(al: AtomicLong, state: State) = Eval.defer(init.runState(al, state)).map(fn) } - private case class FlatMap[A, B](init: RefSpace[A], fn: A => RefSpace[B]) extends RefSpace[B] { + private case class FlatMap[A, B](init: RefSpace[A], fn: A => RefSpace[B]) + extends RefSpace[B] { protected def runState(al: AtomicLong, state: State): Eval[B] = - Eval.defer(init.runState(al, state)) + Eval + .defer(init.runState(al, state)) .flatMap { a => fn(a).runState(al, state) } @@ -95,23 +101,24 @@ object RefSpace { def put(key: Long, value: Any): Unit = discard(map.put(key, value)) def get(key: Long): Option[Any] = map.get(key) - def remove(key: Long): Unit = + def remove(key: Long): Unit = discard(map.remove(key)) } - private[RefSpace] class Fork(under: State, over: MutableMap[Option[Any]]) extends State { + private[RefSpace] class Fork(under: State, over: MutableMap[Option[Any]]) + extends State { def put(key: Long, value: Any): Unit = discard(over.put(key, Some(value))) def get(key: Long): Option[Any] = over.get(key) match { case Some(s) => s - case None => under.get(key) + case None => under.get(key) } def remove(key: Long): Unit = discard(over.put(key, None)) def flush(): Unit = { over.foreach { case (k, Some(v)) => under.put(k, v) - case (k, None) => under.remove(k) + case (k, None) => under.remove(k) } } @@ -122,16 +129,20 @@ object RefSpace { def fork(state: State): Fork = new Fork(state, MutableMap.empty) } - private case class ResetOnLeft[A, B, C](init: RefSpace[A], fn: A => Either[B, C]) extends RefSpace[Either[B, C]] { + private case class ResetOnLeft[A, B, C]( + init: RefSpace[A], + fn: A => Either[B, C] + ) extends RefSpace[Either[B, C]] { protected def runState(al: AtomicLong, state: State): Eval[Either[B, C]] = { val forked = State.fork(state) - init.runState(al, forked) + init + .runState(al, forked) .map { a => fn(a) match { - case r@Right(_) => + case r @ Right(_) => forked.flush() r - case l@Left(_) => + case l @ Left(_) => // just let the forked state disappear l } @@ -159,7 +170,8 @@ object RefSpace { // a counter that starts at 0 val allocCounter: RefSpace[RefSpace[Long]] = - RefSpace.newRef(0L) + RefSpace + .newRef(0L) .map { ref => for { a <- ref.get diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 01f078891..e1df7de96 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -4,7 +4,15 @@ import cats.data.NonEmptyList import cats.parse.{Parser => P, Numbers} import cats.{Applicative, Order} import org.typelevel.paiges.{Doc, Document} -import org.bykn.bosatsu.{Kind, PackageName, Lit, TypeName, Identifier, Parser, TypeParser} +import org.bykn.bosatsu.{ + Kind, + PackageName, + Lit, + TypeName, + Identifier, + Parser, + TypeParser +} import scala.collection.immutable.SortedSet import cats.implicits._ @@ -14,9 +22,9 @@ sealed abstract class Type { } object Type { - /** - * A type with no top level ForAll - */ + + /** A type with no top level ForAll + */ sealed abstract class Rho extends Type sealed abstract class Leaf extends Rho type Tau = Rho // no forall anywhere @@ -30,8 +38,7 @@ object Type { def sameType(left: Type, right: Type): Boolean = if (left.isInstanceOf[Leaf] && right.isInstanceOf[Leaf]) { left == right - } - else { + } else { normalize(left) == normalize(right) } @@ -50,19 +57,22 @@ object Type { val c = list.compare(v0.toList, v1.toList) if (c == 0) compare(i0, i1) else c case (ForAll(_, _), _) => -1 - case (TyConst(Const.Defined(p0, n0)), TyConst(Const.Defined(p1, n1))) => + case ( + TyConst(Const.Defined(p0, n0)), + TyConst(Const.Defined(p1, n1)) + ) => val c = Ordering[PackageName].compare(p0, p1) if (c == 0) Ordering[TypeName].compare(n0, n1) else c case (TyConst(_), ForAll(_, _)) => 1 - case (TyConst(_), _) => -1 + case (TyConst(_), _) => -1 case (TyVar(v0), TyVar(v1)) => Ordering[Var].compare(v0, v1) case (TyVar(_), ForAll(_, _) | TyConst(_)) => 1 - case (TyVar(_), _) => -1 + case (TyVar(_), _) => -1 case (TyMeta(Meta(_, i0, _)), TyMeta(Meta(_, i1, _))) => java.lang.Long.compare(i0, i1) case (TyMeta(_), TyApply(_, _)) => -1 - case (TyMeta(_), _) => 1 + case (TyMeta(_), _) => 1 case (TyApply(a0, b0), TyApply(a1, b1)) => val c = compare(a0, a1) if (c == 0) compare(b0, b1) else c @@ -85,7 +95,7 @@ object Type { def loop(fn: Type, acc: List[Type]): (Type, List[Type]) = fn match { case TyApply(fn, a) => loop(fn, a :: acc) - case notApply => (notApply, acc) + case notApply => (notApply, acc) } loop(fn, Nil) @@ -93,50 +103,48 @@ object Type { def constantsOf(t: Type): List[Const] = t match { - case ForAll(_, t) => constantsOf(t) - case TyApply(on, arg) => constantsOf(on) ::: constantsOf(arg) - case TyConst(c) => c :: Nil + case ForAll(_, t) => constantsOf(t) + case TyApply(on, arg) => constantsOf(on) ::: constantsOf(arg) + case TyConst(c) => c :: Nil case TyVar(_) | TyMeta(_) => Nil } def hasNoVars(t: Type): Boolean = t match { - case TyConst(_) => true + case TyConst(_) => true case TyVar(_) | TyMeta(_) => false - case TyApply(on, arg) => hasNoVars(on) && hasNoVars(arg) - case fa@ForAll(_, _) => freeTyVars(fa :: Nil).isEmpty + case TyApply(on, arg) => hasNoVars(on) && hasNoVars(arg) + case fa @ ForAll(_, _) => freeTyVars(fa :: Nil).isEmpty } final def forAll(vars: List[(Var.Bound, Kind)], in: Type): Type = NonEmptyList.fromList(vars) match { - case None => in + case None => in case Some(ne) => forAll(ne, in) } final def forAll(vars: NonEmptyList[(Var.Bound, Kind)], in: Type): Type = in match { - case rho: Rho => Type.ForAll(vars, rho) + case rho: Rho => Type.ForAll(vars, rho) case Type.ForAll(ne1, rho) => Type.ForAll(vars ::: ne1, rho) } def getTypeOf(lit: Lit): Type = lit match { case Lit.Integer(_) => Type.IntType - case Lit.Str(_) => Type.StrType + case Lit.Str(_) => Type.StrType } - /** - * types are var, meta, or const, or applied or forall on one of - * those. This returns the Type.TyConst found - * by recursing - */ + /** types are var, meta, or const, or applied or forall on one of those. This + * returns the Type.TyConst found by recursing + */ @annotation.tailrec final def rootConst(t: Type): Option[Type.TyConst] = t match { - case tyc@TyConst(_) => Some(tyc) + case tyc @ TyConst(_) => Some(tyc) case TyVar(_) | TyMeta(_) => None - case ForAll(_, r) => rootConst(r) - case TyApply(left, _) => rootConst(left) + case ForAll(_, r) => rootConst(r) + case TyApply(left, _) => rootConst(left) } object RootConst { @@ -149,15 +157,17 @@ object Type { def loop(t: Type, tail: List[Type]): (Type, List[Type]) = t match { case TyApply(left, right) => loop(left, right :: tail) - case notApply => (notApply, tail) + case notApply => (notApply, tail) } loop(t, Nil) } - /** - * This form is often useful in Infer - */ - def substTy(keys: NonEmptyList[Var], vals: NonEmptyList[Type]): Type => Type = { + /** This form is often useful in Infer + */ + def substTy( + keys: NonEmptyList[Var], + vals: NonEmptyList[Type] + ): Type => Type = { val env = keys.toList.iterator.zip(vals.toList.iterator).toMap { t => substituteVar(t, env) } @@ -165,66 +175,72 @@ object Type { def substituteVar(t: Type, env: Map[Type.Var, Type]): Type = t match { - case TyApply(on, arg) => TyApply(substituteVar(on, env), substituteVar(arg, env)) - case v@TyVar(n) => env.getOrElse(n, v) + case TyApply(on, arg) => + TyApply(substituteVar(on, env), substituteVar(arg, env)) + case v @ TyVar(n) => env.getOrElse(n, v) case ForAll(ns, rho) => val boundSet: Set[Var] = ns.toList.iterator.map(_._1).toSet val env1 = env.iterator.filter { case (v, _) => !boundSet(v) }.toMap forAll(ns.toList, substituteVar(rho, env1)) - case m@TyMeta(_) => m - case c@TyConst(_) => c + case m @ TyMeta(_) => m + case c @ TyConst(_) => c } def substituteRhoVar(t: Type.Rho, env: Map[Type.Var, Type.Rho]): Type.Rho = t match { - case TyApply(on, arg) => TyApply(substituteVar(on, env), substituteVar(arg, env)) - case v@TyVar(n) => env.getOrElse(n, v) - case m@TyMeta(_) => m - case c@TyConst(_) => c + case TyApply(on, arg) => + TyApply(substituteVar(on, env), substituteVar(arg, env)) + case v @ TyVar(n) => env.getOrElse(n, v) + case m @ TyMeta(_) => m + case c @ TyConst(_) => c } - /** - * Return the Bound and Skolem variables that - * are free in the given list of types - */ + /** Return the Bound and Skolem variables that are free in the given list of + * types + */ def freeTyVars(ts: List[Type]): List[Type.Var] = { // usually we can recurse in a loop, but sometimes not - def cheat(ts: List[Type], bound: Set[Type.Var.Bound], acc: List[Type.Var]): List[Type.Var] = + def cheat( + ts: List[Type], + bound: Set[Type.Var.Bound], + acc: List[Type.Var] + ): List[Type.Var] = go(ts, bound, acc) @annotation.tailrec - def go(ts: List[Type], bound: Set[Type.Var.Bound], acc: List[Type.Var]): List[Type.Var] = + def go( + ts: List[Type], + bound: Set[Type.Var.Bound], + acc: List[Type.Var] + ): List[Type.Var] = ts match { - case Nil => acc + case Nil => acc case Type.TyVar(tv) :: rest => // we only check here, we don't add val isBound = tv match { - case b@Type.Var.Bound(_) => bound(b) + case b @ Type.Var.Bound(_) => bound(b) case Type.Var.Skolem(_, _, _) => false } if (isBound) go(rest, bound, acc) else go(rest, bound, tv :: acc) case Type.TyApply(a, b) :: rest => go(a :: b :: rest, bound, acc) case Type.ForAll(tvs, ty) :: rest => - val acc1 = cheat(ty :: Nil, bound ++ tvs.toList.iterator.map(_._1), acc) + val acc1 = + cheat(ty :: Nil, bound ++ tvs.toList.iterator.map(_._1), acc) // note, tvs ARE NOT bound in rest go(rest, bound, acc1) case (Type.TyMeta(_) | Type.TyConst(_)) :: rest => go(rest, bound, acc) } - go(ts, Set.empty, Nil) - .reverse - .distinct + go(ts, Set.empty, Nil).reverse.distinct } - /** - * Return the Bound variables that - * are free in the given list of types - */ + /** Return the Bound variables that are free in the given list of types + */ def freeBoundTyVars(ts: List[Type]): List[Type.Var.Bound] = - freeTyVars(ts).collect { case b@Type.Var.Bound(_) => b } + freeTyVars(ts).collect { case b @ Type.Var.Bound(_) => b } def normalize(tpe: Type): Type = tpe match { @@ -239,39 +255,40 @@ object Type { if (vars2.tail.isEmpty) { // already sorted vars2 - } - else { + } else { // sort the quantification by the order of appearance val order = inFree.iterator.zipWithIndex.toMap vars2.sortBy { case (b, _) => order(b) } } val frees = freeBoundTyVars(tpe :: Nil).toSet val bs = alignBinders(vars, frees) - val subMap = bs.toList.map { case ((bold, _), bnew) => - bold -> TyVar(bnew) - } - .toMap[Type.Var, Type.Rho] + val subMap = bs.toList + .map { case ((bold, _), bnew) => + bold -> TyVar(bnew) + } + .toMap[Type.Var, Type.Rho] forAll( bs.toList.map { case ((_, k), b) => (b, k) }, - normalize(substituteRhoVar(in, subMap))) + normalize(substituteRhoVar(in, subMap)) + ) case None => normalize(in) } case TyApply(on, arg) => TyApply(normalize(on), normalize(arg)) - case _ => tpe + case _ => tpe } - /** - * These are upper-case to leverage scala's pattern - * matching on upper-cased vals - */ + + /** These are upper-case to leverage scala's pattern matching on upper-cased + * vals + */ val BoolType: Type.TyConst = TyConst(Const.predef("Bool")) val DictType: Type.TyConst = TyConst(Const.predef("Dict")) object FnType { final val MaxSize = 32 - private def predefFn(n: Int) = TyConst(Const.predef(s"Fn$n")) + private def predefFn(n: Int) = TyConst(Const.predef(s"Fn$n")) private val tpes = (1 to MaxSize).map(predefFn) object ValidArity { @@ -280,7 +297,10 @@ object Type { } def apply(n: Int): Type.TyConst = { - require(ValidArity.unapply(n), s"invalid FnType arity = $n, must be 0 < n <= $MaxSize") + require( + ValidArity.unapply(n), + s"invalid FnType arity = $n, must be 0 < n <= $MaxSize" + ) tpes(n - 1) } @@ -293,7 +313,8 @@ object Type { def unapply(tpe: Type): Option[(Type.TyConst, Int)] = { tpe match { - case Type.TyConst(Const.Predef(cons)) if (cons.asString.startsWith("Fn")) => + case Type.TyConst(Const.Predef(cons)) + if (cons.asString.startsWith("Fn")) => var idx = 0 while (idx < MaxSize) { val thisTpe = tpes(idx) @@ -311,11 +332,9 @@ object Type { def kindSize(n: Int): Kind = Kind((Vector.fill(n)(Kind.Type.contra) :+ Kind.Type.co): _*) - tpes - .iterator - .zipWithIndex - .map { case (t, n1) => (t, kindSize(n1 + 1)) } - .toList + tpes.iterator.zipWithIndex.map { case (t, n1) => + (t, kindSize(n1 + 1)) + }.toList } } val IntType: Type.TyConst = TyConst(Const.predef("Int")) @@ -334,10 +353,8 @@ object Type { ListType -> Kind(Kind.Type.co), StrType -> Kind.Type, UnitType -> Kind.Type, - TupleConsType -> Kind(Kind.Type.co, Kind.Type.co), - )) - .map { case (t, k) => (t.tpe.toDefined, k) } - .toMap + TupleConsType -> Kind(Kind.Type.co, Kind.Type.co) + )).map { case (t, k) => (t.tpe.toDefined, k) }.toMap def const(pn: PackageName, name: TypeName): Type = TyConst(Type.Const.Defined(pn, name)) @@ -351,7 +368,12 @@ object Type { } def unapply(t: Type): Option[(NonEmptyList[Type], Type)] = { - def check(n: Int, t: Type, applied: List[Type], last: Type): Option[(NonEmptyList[Type], Type)] = + def check( + n: Int, + t: Type, + applied: List[Type], + last: Type + ): Option[(NonEmptyList[Type], Type)] = t match { case TyApply(inner, arg) => check(n + 1, inner, arg :: applied, last) @@ -382,7 +404,7 @@ object Type { t match { case ForAll(_, t) => arity(t) case Fun(args, _) => args.length - case _ => 0 + case _ => 0 } } @@ -392,7 +414,7 @@ object Type { case UnitType => Some(Nil) case TyApply(TyApply(TupleConsType, h), t) => unapply(t) match { - case None => None + case None => None case Some(ts) => Some(h :: ts) } case _ => None @@ -411,7 +433,7 @@ object Type { def unapply(t: Type): Option[Type] = t match { case TyApply(OptionType, t) => Some(t) - case _ => None + case _ => None } } @@ -419,7 +441,7 @@ object Type { def unapply(t: Type): Option[(Type, Type)] = t match { case TyApply(TyApply(DictType, kt), vt) => Some((kt, vt)) - case _ => None + case _ => None } } @@ -427,7 +449,7 @@ object Type { def unapply(t: Type): Option[Type] = t match { case TyApply(ListType, t) => Some(t) - case _ => None + case _ => None } } @@ -446,7 +468,7 @@ object Type { def unapply(c: Const): Option[Identifier.Constructor] = c match { case Defined(PackageName.PredefName, TypeName(cons)) => Some(cons) - case _ => None + case _ => None } } } @@ -467,10 +489,8 @@ object Type { val c = str.charAt(0) if ('a' <= c && c <= 'z') { cache(c - 'a') - } - else new Bound(str) - } - else new Bound(str) + } else new Bound(str) + } else new Bound(str) } implicit val varOrdering: Ordering[Var] = @@ -478,7 +498,7 @@ object Type { def compare(a: Var, b: Var): Int = (a, b) match { case (Bound(a), Bound(b)) => a.compareTo(b) - case (Bound(_), _) => -1 + case (Bound(_), _) => -1 case (Skolem(n0, k0, i0), Skolem(n1, k1, i1)) => val c = java.lang.Long.compare(i0, i1) if (c != 0) c @@ -504,10 +524,15 @@ object Type { letters.map { c => Var.Bound(c.toString) } #::: lettersWithNumber } - def alignBinders[A](items: NonEmptyList[A], avoid: Set[Var.Bound]): NonEmptyList[(A, Var.Bound)] = { + def alignBinders[A]( + items: NonEmptyList[A], + avoid: Set[Var.Bound] + ): NonEmptyList[(A, Var.Bound)] = { val sz = items.size // for some reason on 2.11 we need to do .iterator or this will be an infinite loop - val bs = NonEmptyList.fromListUnsafe(allBinders.iterator.filterNot(avoid).take(sz).toList) + val bs = NonEmptyList.fromListUnsafe( + allBinders.iterator.filterNot(avoid).take(sz).toList + ) NonEmptyList((items.head, bs.head), items.tail.zip(bs.tail)) } @@ -518,26 +543,24 @@ object Type { Ordering.by { (m: Meta) => m.id } } - /** - * Final the set of all of Metas inside the list of given types - */ + /** Final the set of all of Metas inside the list of given types + */ def metaTvs(s: List[Type]): SortedSet[Meta] = { @annotation.tailrec def go(check: List[Type], acc: SortedSet[Meta]): SortedSet[Meta] = check match { - case Nil => acc - case ForAll(_, r) :: tail => go(r :: tail, acc) + case Nil => acc + case ForAll(_, r) :: tail => go(r :: tail, acc) case TyApply(a, r) :: tail => go(a :: r :: tail, acc) - case TyMeta(m) :: tail => go(tail, acc + m) - case _ :: tail => go(tail, acc) + case TyMeta(m) :: tail => go(tail, acc + m) + case _ :: tail => go(tail, acc) } go(s, SortedSet.empty) } - /** - * Report bound variables which are used in quantify. When we - * infer a sigma type - */ + /** Report bound variables which are used in quantify. When we infer a sigma + * type + */ def tyVarBinders(tpes: List[Type]): Set[Type.Var.Bound] = { @annotation.tailrec def loop(tpes: List[Type], acc: Set[Type.Var.Bound]): Set[Type.Var.Bound] = @@ -552,26 +575,28 @@ object Type { loop(tpes, Set.empty) } - /** - * Transform meta variables in some way - */ - def zonkMeta[F[_]: Applicative](t: Type)(m: Meta => F[Option[Type.Rho]]): F[Type] = + /** Transform meta variables in some way + */ + def zonkMeta[F[_]: Applicative]( + t: Type + )(m: Meta => F[Option[Type.Rho]]): F[Type] = t match { case rho: Rho => zonkRhoMeta(rho)(m).widen case ForAll(ns, ty) => zonkRhoMeta(ty)(m).map(Type.ForAll(ns, _)) } - /** - * Transform meta variables in some way - */ - def zonkRhoMeta[F[_]: Applicative](t: Type.Rho)(mfn: Meta => F[Option[Type.Rho]]): F[Type.Rho] = + /** Transform meta variables in some way + */ + def zonkRhoMeta[F[_]: Applicative]( + t: Type.Rho + )(mfn: Meta => F[Option[Type.Rho]]): F[Type.Rho] = t match { case Type.TyApply(on, arg) => (zonkMeta(on)(mfn), zonkMeta(arg)(mfn)).mapN(Type.TyApply(_, _)) - case t@Type.TyMeta(m) => + case t @ Type.TyMeta(m) => mfn(m).map { - case None => t + case None => t case Some(rho) => rho } case (Type.TyConst(_) | Type.TyVar(_)) => Applicative[F].pure(t) @@ -580,8 +605,11 @@ object Type { private object FullResolved extends TypeParser[Type] { lazy val parseRoot = { val tvar = Parser.lowerIdent.map { s => Type.TyVar(Type.Var.Bound(s)) } - val name = ((PackageName.parser <* P.string("::")) ~ Identifier.consParser) - .map { case (p, n) => Type.TyConst(Type.Const.Defined(p, TypeName(n))) } + val name = + ((PackageName.parser <* P.string("::")) ~ Identifier.consParser) + .map { case (p, n) => + Type.TyConst(Type.Const.Defined(p, TypeName(n))) + } val longParser: P[Long] = Numbers.signedIntString.mapFilter { str => try Some(str.toLong) catch { @@ -598,7 +626,9 @@ object Type { // the ideal solution is to better static type information // to have fully inferred types with no skolems or metas // TODO Kind - val meta = (P.char('?') *> longParser).map { l => TyMeta(Meta(Kind.Type, l, null)) } + val meta = (P.char('?') *> longParser).map { l => + TyMeta(Meta(Kind.Type, l, null)) + } tvar.orElse(name).orElse(skolem).orElse(meta) } @@ -607,13 +637,17 @@ object Type { // this may be an invalid function, but typechecking verifies that. Type.Fun(in, out) - def applyTypes(left: Type, args: NonEmptyList[Type]) = applyAll(left, args.toList) + def applyTypes(left: Type, args: NonEmptyList[Type]) = + applyAll(left, args.toList) def universal(vs: NonEmptyList[(String, Option[Kind])], on: Type) = - Type.forAll(vs.map { - case (s, None) => (Type.Var.Bound(s), Kind.Type) - case (s, Some(k)) => (Type.Var.Bound(s), k) - }, on) + Type.forAll( + vs.map { + case (s, None) => (Type.Var.Bound(s), Kind.Type) + case (s, Some(k)) => (Type.Var.Bound(s), k) + }, + on + ) def makeTuple(lst: List[Type]) = Type.Tuple(lst) @@ -622,8 +656,11 @@ object Type { def unapplyRoot(a: Type): Option[Doc] = a match { case TyConst(Const.Defined(p, n)) => - Some(Document[PackageName].document(p) + coloncolon + Document[Identifier].document(n.ident)) - case TyVar(Var.Bound(s)) => Some(Doc.text(s)) + Some( + Document[PackageName] + .document(p) + coloncolon + Document[Identifier].document(n.ident) + ) + case TyVar(Var.Bound(s)) => Some(Doc.text(s)) case TyVar(Var.Skolem(n, _, i)) => // TODO Kind val dol = "$" @@ -637,33 +674,40 @@ object Type { def unapplyFn(a: Type): Option[(NonEmptyList[Type], Type)] = a match { case Fun(as, b) => Some((as, b)) - case _ => None + case _ => None } - def unapplyUniversal(a: Type): Option[(List[(String, Option[Kind])], Type)] = + def unapplyUniversal( + a: Type + ): Option[(List[(String, Option[Kind])], Type)] = a match { case ForAll(vs, arg) => - Some((vs.map { - case (v, k) => (v.name, Some(k)) - }.toList, arg)) + Some( + ( + vs.map { case (v, k) => + (v.name, Some(k)) + }.toList, + arg + ) + ) case _ => None } def unapplyTypeApply(a: Type): Option[(Type, List[Type])] = a match { - case ta@TyApply(_, _) => Some(unapplyAll(ta)) - case _ => None + case ta @ TyApply(_, _) => Some(unapplyAll(ta)) + case _ => None } def unapplyTuple(a: Type): Option[List[Type]] = a match { case Tuple(as) => Some(as) - case _ => None + case _ => None } } - /** - * Parse fully resolved types: package::type - */ + + /** Parse fully resolved types: package::type + */ def fullyResolvedParser: P[Type] = FullResolved.parser def fullyResolvedDocument: Document[Type] = FullResolved.document def typeParser: TypeParser[Type] = FullResolved diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala index 5239aa6f2..f7cb2f1f3 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala @@ -5,9 +5,13 @@ import org.bykn.bosatsu.Identifier.{Bindable, Constructor} import scala.collection.immutable.SortedMap class TypeEnv[+A] private ( - protected val values: SortedMap[(PackageName, Identifier), Type], - protected val constructors: SortedMap[(PackageName, Constructor), (DefinedType[A], ConstructorFn)], - val definedTypes: SortedMap[(PackageName, TypeName), DefinedType[A]]) { + protected val values: SortedMap[(PackageName, Identifier), Type], + protected val constructors: SortedMap[ + (PackageName, Constructor), + (DefinedType[A], ConstructorFn) + ], + val definedTypes: SortedMap[(PackageName, TypeName), DefinedType[A]] +) { override def equals(that: Any): Boolean = that match { @@ -35,10 +39,16 @@ class TypeEnv[+A] private ( def allDefinedTypes: List[DefinedType[A]] = definedTypes.values.toList.sortBy { dt => (dt.packageName, dt.name) } - def getConstructor(p: PackageName, c: Constructor): Option[(DefinedType[A], ConstructorFn)] = + def getConstructor( + p: PackageName, + c: Constructor + ): Option[(DefinedType[A], ConstructorFn)] = constructors.get((p, c)) - def getConstructorParams(p: PackageName, c: Constructor): Option[List[(Bindable, Type)]] = + def getConstructorParams( + p: PackageName, + c: Constructor + ): Option[List[(Bindable, Type)]] = constructors.get((p, c)).map(_._2.args) def getType(p: PackageName, t: TypeName): Option[DefinedType[A]] = @@ -48,7 +58,9 @@ class TypeEnv[+A] private ( values.get((p, n)) // when we have resolved, we can get the types of constructors out - def getValue(p: PackageName, n: Identifier)(implicit ev: A <:< Kind.Arg): Option[Type] = + def getValue(p: PackageName, n: Identifier)(implicit + ev: A <:< Kind.Arg + ): Option[Type] = n match { case c @ Constructor(_) => // constructors are never external defs @@ -57,65 +69,82 @@ class TypeEnv[+A] private ( } // when we have resolved, we can get the types of constructors out - def localValuesOf(p: PackageName)(implicit ev: A <:< Kind.Arg): SortedMap[Identifier, Type] = { + def localValuesOf( + p: PackageName + )(implicit ev: A <:< Kind.Arg): SortedMap[Identifier, Type] = { val bldr = SortedMap.newBuilder[Identifier, Type] // add externals bldr ++= values.iterator.collect { case ((pn, n), v) if pn == p => (n, v) } // add constructors - bldr ++= constructors.iterator.collect { case ((pn, n), (dt, cf)) if pn == p => (n, dt.fnTypeOf(cf)) } + bldr ++= constructors.iterator.collect { + case ((pn, n), (dt, cf)) if pn == p => (n, dt.fnTypeOf(cf)) + } bldr.result() } - def addConstructor[A1 >: A](pack: PackageName, - dt: DefinedType[A1], - cf: ConstructorFn): TypeEnv[A1] = { - val nec = constructors.updated((pack, cf.name), (dt, cf)) - val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) - new TypeEnv(values = values, constructors = nec, definedTypes = dt1) - } + def addConstructor[A1 >: A]( + pack: PackageName, + dt: DefinedType[A1], + cf: ConstructorFn + ): TypeEnv[A1] = { + val nec = constructors.updated((pack, cf.name), (dt, cf)) + val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) + new TypeEnv(values = values, constructors = nec, definedTypes = dt1) + } - /** - * only add the type, do not add any of the constructors - * used when importing values - */ + /** only add the type, do not add any of the constructors used when importing + * values + */ def addDefinedType[A1 >: A](dt: DefinedType[A1]): TypeEnv[A1] = { val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) - new TypeEnv(constructors = constructors, definedTypes = dt1, values = values) + new TypeEnv( + constructors = constructors, + definedTypes = dt1, + values = values + ) } - /** - * add a DefinedType and all of its constructors. This is done locally for - * a package - */ - def addDefinedTypeAndConstructors[A1 >: A](dt: DefinedType[A1]): TypeEnv[A1] = { + /** add a DefinedType and all of its constructors. This is done locally for a + * package + */ + def addDefinedTypeAndConstructors[A1 >: A]( + dt: DefinedType[A1] + ): TypeEnv[A1] = { val dt1 = definedTypes.updated((dt.packageName, dt.name), dt) val cons1 = dt.constructors - .foldLeft(constructors: SortedMap[(PackageName, Constructor), (DefinedType[A1], ConstructorFn)]) { - case (cons0, cf) => - cons0.updated((dt.packageName, cf.name), (dt, cf)) + .foldLeft( + constructors: SortedMap[ + (PackageName, Constructor), + (DefinedType[A1], ConstructorFn) + ] + ) { case (cons0, cf) => + cons0.updated((dt.packageName, cf.name), (dt, cf)) } new TypeEnv(constructors = cons1, definedTypes = dt1, values = values) } - /** - * External values cannot be inferred and have to be fully - * annotated - */ - def addExternalValue(pack: PackageName, name: Identifier, t: Type): TypeEnv[A] = + /** External values cannot be inferred and have to be fully annotated + */ + def addExternalValue( + pack: PackageName, + name: Identifier, + t: Type + ): TypeEnv[A] = new TypeEnv( constructors = constructors, definedTypes = definedTypes, - values = values.updated((pack, name), t)) + values = values.updated((pack, name), t) + ) - lazy val typeConstructors: SortedMap[(PackageName, Constructor), (List[(Type.Var.Bound, A)], List[Type], Type.Const.Defined)] = + lazy val typeConstructors: SortedMap[ + (PackageName, Constructor), + (List[(Type.Var.Bound, A)], List[Type], Type.Const.Defined) + ] = constructors.map { case (pc, (dt, cf)) => - (pc, - (dt.annotatedTypeParams, - cf.args.map(_._2), - dt.toTypeConst)) + (pc, (dt.annotatedTypeParams, cf.args.map(_._2), dt.toTypeConst)) } def definedTypeFor(c: (PackageName, Constructor)): Option[DefinedType[A]] = @@ -127,9 +156,11 @@ class TypeEnv[+A] private ( } def ++[A1 >: A](that: TypeEnv[A1]): TypeEnv[A1] = - new TypeEnv(values ++ that.values, + new TypeEnv( + values ++ that.values, constructors ++ that.constructors, - definedTypes ++ that.definedTypes) + definedTypes ++ that.definedTypes + ) def toKindMap(implicit ev: A <:< Kind.Arg): Map[Type.Const.Defined, Kind] = { type F[+Z] = List[DefinedType[Z]] @@ -142,17 +173,24 @@ object TypeEnv { val empty: TypeEnv[Nothing] = new TypeEnv( SortedMap.empty[(PackageName, Identifier), Type], - SortedMap.empty[(PackageName, Constructor), (DefinedType[Nothing], ConstructorFn)], - SortedMap.empty[(PackageName, TypeName), DefinedType[Nothing]]) - - /** - * Adds all the types and all the constructors from the given types - */ + SortedMap.empty[ + (PackageName, Constructor), + (DefinedType[Nothing], ConstructorFn) + ], + SortedMap.empty[(PackageName, TypeName), DefinedType[Nothing]] + ) + + /** Adds all the types and all the constructors from the given types + */ def fromDefinitions[A](defs: List[DefinedType[A]]): TypeEnv[A] = defs.foldLeft(empty: TypeEnv[A])(_.addDefinedTypeAndConstructors(_)) def fromParsed[A](p: ParsedTypeEnv[A]): TypeEnv[A] = { - val t1 = p.allDefinedTypes.foldLeft(empty: TypeEnv[A])(_.addDefinedTypeAndConstructors(_)) - p.externalDefs.foldLeft(t1) { case (t1, (p, n, t)) => t1.addExternalValue(p, n, t) } + val t1 = p.allDefinedTypes.foldLeft(empty: TypeEnv[A])( + _.addDefinedTypeAndConstructors(_) + ) + p.externalDefs.foldLeft(t1) { case (t1, (p, n, t)) => + t1.addExternalValue(p, n, t) + } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala b/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala index f812b4e62..6998905bb 100644 --- a/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/CollectionUtilsTest.scala @@ -1,13 +1,16 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class CollectionUtilsTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) test("listToUnique works for maps converted to lists") { forAll { (m: Map[Int, Int]) => diff --git a/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala b/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala index f162145d5..b57f0893c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DeclarationTest.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Identifier.Bindable @@ -14,27 +17,32 @@ class DeclarationTest extends AnyFunSuite { import Generators.shrinkDecl implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 200 else 20) - //PropertyCheckConfiguration(minSuccessful = 50) + // PropertyCheckConfiguration(minSuccessful = 5000) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 200 else 20 + ) + // PropertyCheckConfiguration(minSuccessful = 50) implicit val emptyRegion: Region = Region(0, 0) val genDecl = Generators.genDeclaration(depth = 4) lazy val genNonFree: Gen[Declaration.NonBinding] = - genDecl.flatMap { - case decl: Declaration.NonBinding if decl.freeVars.isEmpty => Gen.const(decl) - case _ => genNonFree - } - + genDecl.flatMap { + case decl: Declaration.NonBinding if decl.freeVars.isEmpty => + Gen.const(decl) + case _ => genNonFree + } test("freeVarsSet is a subset of allVars") { forAll(genDecl) { decl => val frees = decl.freeVars val av = decl.allNames val missing = frees -- av - assert(missing.isEmpty, s"expression:\n\n${decl}\n\nallVars: $av\n\nfrees: $frees") + assert( + missing.isEmpty, + s"expression:\n\n${decl}\n\nallVars: $av\n\nfrees: $frees" + ) } } @@ -53,8 +61,10 @@ class DeclarationTest extends AnyFunSuite { val d1Str = d1.toDoc.render(80) val dSubStr = d0sub.toDoc.render(80) - assert(!d0sub.freeVars.contains(b), - s"subs:\n\n$d0Str\n\n===============\n\n$d1Str===============\n\n$dSubStr") + assert( + !d0sub.freeVars.contains(b), + s"subs:\n\n$d0Str\n\n===============\n\n$d1Str===============\n\n$dSubStr" + ) } } } @@ -67,11 +77,11 @@ class DeclarationTest extends AnyFunSuite { lazy val notFree: Gen[Bindable] = Generators.bindIdentGen.flatMap { case b if frees(b) => notFree - case b => Gen.const(b) + case b => Gen.const(b) } notFree.map((decl, _)) - } + } def law(b: Bindable, d1: Declaration.NonBinding, d0: Declaration) = { val frees = d0.freeVars @@ -90,8 +100,7 @@ class DeclarationTest extends AnyFunSuite { // there must be some diff val diffPos = - left - .iterator + left.iterator .zip(right.iterator) .zipWithIndex .dropWhile { case ((a, b), _) => a == b } @@ -102,7 +111,8 @@ class DeclarationTest extends AnyFunSuite { val leftAt = left.drop(diffPos).take(50) val rightAt = right.drop(diffPos).take(50) val diff = s"offset: $diffPos$line$leftAt\n\n$line$rightAt" - val msg = s"left$line${left}\n\nright$line$right\n\ndiff$line$diff" + val msg = + s"left$line${left}\n\nright$line$right\n\ndiff$line$diff" assert(false, msg) } } @@ -123,10 +133,74 @@ class DeclarationTest extends AnyFunSuite { val b = Identifier.Backticked("") val d1 = Literal(Lit.fromInt(0)) val d0 = DefFn( - DefStatement(Name("mfLjwok"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("foo")))),None, - (NotSameLine(Padding(10,Indented(10,Var(Backticked(""))))), - Padding(10,Binding(BindingStatement( - Pattern.Var(Backticked("")),Var(Constructor("Rgt")),Padding(1,DefFn(DefStatement(Backticked(""),None,NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("bar")))),None,(NotSameLine(Padding(2,Indented(4,Literal(Lit.fromInt(42))))),Padding(2,DefFn(DefStatement(Name("gkxAckqpatu"),None, NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("quux")))),Some(TypeRef.TypeName(TypeName(Constructor("Y")))),(NotSameLine(Padding(6,Indented(8,Literal(Lit("oimsu"))))),Padding(2,Var(Name("j"))))))))))))))))) + DefStatement( + Name("mfLjwok"), + None, + NonEmptyList.one(NonEmptyList.one(Pattern.Var(Name("foo")))), + None, + ( + NotSameLine(Padding(10, Indented(10, Var(Backticked(""))))), + Padding( + 10, + Binding( + BindingStatement( + Pattern.Var(Backticked("")), + Var(Constructor("Rgt")), + Padding( + 1, + DefFn( + DefStatement( + Backticked(""), + None, + NonEmptyList.one( + NonEmptyList.one(Pattern.Var(Name("bar"))) + ), + None, + ( + NotSameLine( + Padding( + 2, + Indented(4, Literal(Lit.fromInt(42))) + ) + ), + Padding( + 2, + DefFn( + DefStatement( + Name("gkxAckqpatu"), + None, + NonEmptyList.one( + NonEmptyList.one( + Pattern.Var(Name("quux")) + ) + ), + Some( + TypeRef.TypeName( + TypeName(Constructor("Y")) + ) + ), + ( + NotSameLine( + Padding( + 6, + Indented(8, Literal(Lit("oimsu"))) + ) + ), + Padding(2, Var(Name("j"))) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) + ) (b, d1, d0) } @@ -141,7 +215,7 @@ class DeclarationTest extends AnyFunSuite { genDecl.flatMap { decl => val frees = decl.freeVars.toList frees match { - case Nil => genFrees + case Nil => genFrees case nonEmpty => Gen.oneOf(nonEmpty).map((decl, _)) } } @@ -178,13 +252,17 @@ class DeclarationTest extends AnyFunSuite { val resD = res.map(unsafeParse(Declaration.parser(""), _)) val b = unsafeParse(Identifier.bindableParser, bStr) - assert(Declaration.substitute(b, d1.toNonBinding, d0) == resD) } - law("b", "12", """x = b -x""", Some("""x = 12 -x""")) + law( + "b", + "12", + """x = b +x""", + Some("""x = 12 +x""") + ) law("b", "12", """[x for b in y]""", Some("""[x for b in y]""")) law("b", "12", """[b for z in y]""", Some("""[12 for z in y]""")) @@ -209,8 +287,16 @@ x""")) law("[a for b in c if b]", List("a", "c"), List("a", "b", "c")) law("[b for b in c if d]", List("c", "d"), List("b", "c", "d")) law("[b for b in c if b]", List("c"), List("b", "c")) - law("{ k: a for b in c if d}", List("k", "a", "c", "d"), List("k", "a", "b", "c", "d")) - law("{ k: a for b in c if b}", List("k", "a", "c"), List("k", "a", "b", "c")) + law( + "{ k: a for b in c if d}", + List("k", "a", "c", "d"), + List("k", "a", "b", "c", "d") + ) + law( + "{ k: a for b in c if b}", + List("k", "a", "c"), + List("k", "a", "b", "c") + ) law("Foo { a }", List("a"), List("a")) law("Foo { a: b }", List("b"), List("b")) } diff --git a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala index e6c4e6cba..840514408 100644 --- a/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/DefRecursionCheckTest.scala @@ -19,7 +19,7 @@ class DefRecursionCheckTest extends AnyFunSuite { def disallowed(teStr: String) = { val stmt = TestUtils.statementsOf(teStr) stmt.traverse_(DefRecursionCheck.checkStatement(_)) match { - case Validated.Valid(_) => fail("expected failure") + case Validated.Valid(_) => fail("expected failure") case Validated.Invalid(_) => succeed } } @@ -376,7 +376,7 @@ def nest(lst): } test("we can't use an outer def recursively") { -disallowed("""# + disallowed("""# def foo(x): def bar(y): foo(y) @@ -385,7 +385,7 @@ def foo(x): } test("we can make a recursive def in another recursive def") { -allowed("""# + allowed("""# def len(lst): # this is doing nothing, but is a nested recursion def len0(lst): @@ -400,7 +400,7 @@ def len(lst): } test("we can call a non-outer function in a recur branch") { -allowed("""# + allowed("""# def id(x): x def len(lst): diff --git a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala index a08e9156c..d03b1a86d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala @@ -15,7 +15,10 @@ class EvaluationTest extends AnyFunSuite with ParTest { package Foo x = 1 -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest(List("x = 1"), "Package0", VInt(1)) @@ -23,7 +26,10 @@ x = 1 List(""" # test shadowing x = match 1: case x: x -"""), "Package0", VInt(1)) +"""), + "Package0", + VInt(1) + ) evalTest( List(""" @@ -31,7 +37,10 @@ package Foo # exercise calling directly a lambda x = (y -> y)("hello") -"""), "Foo", Str("hello")) +"""), + "Foo", + Str("hello") + ) runBosatsuTest( List(""" @@ -45,7 +54,10 @@ def eq_String(a, b): case _: False test = Assertion(eq_String("hello", foo), "checking equality") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) runBosatsuTest( List(""" @@ -59,7 +71,10 @@ foo = ( ) test = Assertion(foo matches 4, "checking equality") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) runBosatsuTest( List(""" @@ -69,7 +84,10 @@ test = TestSuite("three trivial tests", [ Assertion(True, "t0"), Assertion(True, "t1"), Assertion(True, "t2"), ]) -"""), "Foo", 3) +"""), + "Foo", + 3 + ) } test("test if/else") { @@ -84,7 +102,10 @@ z = match x.cmp_Int(1): "foo" case _: "bar" -"""), "Foo", Str("foo")) +"""), + "Foo", + Str("foo") + ) evalTest( List(""" @@ -94,7 +115,10 @@ x = 1 # here if the single expression python style z = "foo" if x.eq_Int(2) else "bar" -"""), "Foo", Str("bar")) +"""), + "Foo", + Str("bar") + ) } test("exercise option from predef") { @@ -107,7 +131,10 @@ x = Some(1) z = match x: case Some(v): add(v, 10) case None: 0 -"""), "Foo", VInt(11)) +"""), + "Foo", + VInt(11) + ) // Use a local name collision and see it not have a problem evalTest( @@ -121,7 +148,10 @@ x = Some(1) z = match x: case Some(v): add(v, 10) case None: 0 -"""), "Foo", VInt(11)) +"""), + "Foo", + VInt(11) + ) evalTest( List(""" @@ -135,7 +165,10 @@ x = Some(1) z = match x: case None: 0 case Some(v): add(v, 10) -"""), "Foo", VInt(11)) +"""), + "Foo", + VInt(11) + ) } test("test matching unions") { @@ -150,7 +183,10 @@ x = Pair(Pair(1, "1"), "2") main = match x: Pair(_, "2" | "3"): "good" _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List("""package Foo @@ -164,7 +200,10 @@ def run(z): y main = run(x) -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -179,10 +218,12 @@ def run(z): y main = run(x) -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) - evalFail( - List(""" + evalFail(List(""" package Err enum IntOrString: IntCase(i: Int), StringCase(i: Int, s: String) @@ -207,10 +248,15 @@ def go(x): main = go(IntCase(42)) """ - val packs = Map((PackageName.parts("Err"), (LocationMap(errPack), "Err.bosatsu"))) - evalFail(List(errPack)) { case te@PackageError.TypeErrorIn(_, _) => + val packs = + Map((PackageName.parts("Err"), (LocationMap(errPack), "Err.bosatsu"))) + evalFail(List(errPack)) { case te @ PackageError.TypeErrorIn(_, _) => val msg = te.message(packs, Colorize.None) - assert(msg.contains("type error: expected type Bosatsu/Predef::Int to be the same as type Bosatsu/Predef::String")) + assert( + msg.contains( + "type error: expected type Bosatsu/Predef::Int to be the same as type Bosatsu/Predef::String" + ) + ) () } @@ -227,7 +273,10 @@ def go(x): 42 main = go(IntCase(42)) -"""), "Union", VInt(42)) +"""), + "Union", + VInt(42) + ) } test("test matching literals") { @@ -240,7 +289,10 @@ x = 1 main = match x: case 1: "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -252,7 +304,10 @@ x = [1] main = match x: EmptyList: "empty" NonEmptyList(...): "notempty" -"""), "Foo", Str("notempty")) +"""), + "Foo", + Str("notempty") + ) evalTest( List(""" @@ -263,7 +318,10 @@ x = "1" main = match x: case "1": "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -276,7 +334,10 @@ x = Pair(1, "1") main = match x: case Pair(_, "1"): "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) } test("test tuples") { @@ -289,7 +350,10 @@ x = (1, "1") main = match x: case (_, "1"): "good" case _: "bad" -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) evalTest( List(""" @@ -306,7 +370,10 @@ def go(u): case _: "bad" main = go(()) -"""), "Foo", Str("good")) +"""), + "Foo", + Str("good") + ) } test("do a fold") { @@ -323,7 +390,10 @@ sum0 = sum(three) sum1 = three.foldLeft(0, (x, y) -> add(x, y)) same = sum0.eq_Int(sum1) -"""), "Foo", True) +"""), + "Foo", + True + ) evalTest( List(""" @@ -335,7 +405,10 @@ sum0 = three.foldLeft(0, add) sum1 = three.foldLeft(0, \x, y -> add(x, y)) same = sum0.eq_Int(sum1) -"""), "Foo", True) +"""), + "Foo", + True + ) } @@ -345,7 +418,10 @@ same = sum0.eq_Int(sum1) package Foo main = 6.mod_Int(4) -"""), "Foo", VInt(2)) +"""), + "Foo", + VInt(2) + ) evalTest( List(""" @@ -355,14 +431,20 @@ main = match 6.div(4): case 0: 42 case 1: 100 case x: x -"""), "Foo", VInt(100)) +"""), + "Foo", + VInt(100) + ) evalTest( List(""" package Foo main = 6.gcd_Int(3) -"""), "Foo", VInt(3)) +"""), + "Foo", + VInt(3) + ) } test("use range") { @@ -396,10 +478,13 @@ def eq_list(a, b, fn): same_items(zip(a, b), fn) same = eq_list(three, threer, eq_Int) -"""), "Foo", True) +"""), + "Foo", + True + ) -evalTest( - List(""" + evalTest( + List(""" package Foo def zip(as: List[a], bs: List[b]) -> List[(a, b)]: @@ -411,31 +496,43 @@ def zip(as: List[a], bs: List[b]) -> List[(a, b)]: case [bh, *btail]: [(ah, bh), *zip(atail, btail)] main = 1 -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) } test("test range_fold") { -evalTest( - List(""" + evalTest( + List(""" package Foo main = range_fold(0, 10, 0, add) -"""), "Foo", VInt(45)) +"""), + "Foo", + VInt(45) + ) -evalTest( - List(""" + evalTest( + List(""" package Foo main = range_fold(0, 10, 0, (_, y) -> y) -"""), "Foo", VInt(9)) +"""), + "Foo", + VInt(9) + ) -evalTest( - List(""" + evalTest( + List(""" package Foo main = range_fold(0, 10, 100, (x, _) -> x) -"""), "Foo", VInt(100)) +"""), + "Foo", + VInt(100) + ) } test("test some list matches") { @@ -449,7 +546,10 @@ def headOption(as): case [a, *_]: Some(a) main = headOption([1]) -"""), "Foo", SumValue(1, ConsValue(VInt(1), UnitValue))) +"""), + "Foo", + SumValue(1, ConsValue(VInt(1), UnitValue)) + ) runBosatsuTest( List(""" @@ -467,7 +567,10 @@ test = TestSuite("exists", [ Assertion(not(exists([])), "![]"), Assertion(not(exists([False])), "![False]"), ]) -"""), "Foo", 5) +"""), + "Foo", + 5 + ) } test("test generics in defs") { @@ -479,7 +582,10 @@ def id(x: a) -> a: x main = id(1) -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) } test("exercise struct creation") { @@ -490,8 +596,10 @@ package Foo struct Bar(a: Int) main = Bar(1) -"""), "Foo", - VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest( List(""" @@ -501,7 +609,10 @@ struct Bar(a: Int) # destructuring top-level let Bar(main) = Bar(1) -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest( List(""" @@ -511,7 +622,10 @@ struct Bar(a: Int) # destructuring top-level let Bar(main: Int) = Bar(1) -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTest( List(""" @@ -522,7 +636,10 @@ struct Bar(a: Int) y = Bar(1) # destructuring top-level let Bar(main: Int) = y -"""), "Foo", VInt(1)) +"""), + "Foo", + VInt(1) + ) evalTestJson( List(""" @@ -531,12 +648,16 @@ package Foo struct Bar(a: Int, s: String) main = Bar(1, "foo") -"""), "Foo", Json.JObject(List("a" -> Json.JNumberStr("1"), "s" -> Json.JString("foo")))) +"""), + "Foo", + Json.JObject( + List("a" -> Json.JNumberStr("1"), "s" -> Json.JString("foo")) + ) + ) } test("test some type errors") { - evalFail( - List(""" + evalFail(List(""" package Foo main = if True: @@ -546,7 +667,9 @@ else: """)) { case PackageError.TypeErrorIn(_, _) => () } } - test("test the list literals work even when we have conflicting local names") { + test( + "test the list literals work even when we have conflicting local names" + ) { evalTest( List(""" package Foo @@ -554,8 +677,10 @@ package Foo struct EmptyList main = [1, 2] -"""), "Foo", - VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil))) +"""), + "Foo", + VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil)) + ) evalTest( List(""" @@ -564,8 +689,10 @@ package Foo struct NonEmptyList main = [1, 2] -"""), "Foo", - VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil))) +"""), + "Foo", + VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil)) + ) evalTest( List(""" @@ -574,13 +701,14 @@ package Foo def concat(a): a main = [1, 2] -"""), "Foo", - VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil))) +"""), + "Foo", + VList.Cons(VInt(1), VList.Cons(VInt(2), VList.VNil)) + ) } test("forbid the y-combinator") { - evalFail( - List(""" + evalFail(List(""" package Y struct W(fn: W[a, b] -> a -> b) @@ -599,14 +727,14 @@ def ltEqZero(i): fac = trace("made fac", y(\f, i -> 1 if ltEqZero(i) else f(i).times(i))) main = fac(6) -""")) { - case PackageError.KindInferenceError(_, _, _) => () - } +""")) { case PackageError.KindInferenceError(_, _, _) => + () + } } test("check type aligned enum") { - evalTest( - List(""" + evalTest( + List(""" package A enum GoodOrBad: @@ -618,22 +746,27 @@ def unbox(gb: GoodOrBad[a]): case Bad(b): b (main: Int) = unbox(Good(42)) -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum GoodOrBad: Bad(a: a), Good(a: a) Bad(main) | Good(main) = Good(42) -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) } test("nontotal matches fail even if not at runtime") { - evalFail( - List(""" + evalFail(List(""" package Total enum Opt: Nope, Yep(get) @@ -649,8 +782,7 @@ main = one } test("unreachable patterns are an error") { - evalFail( - List(""" + evalFail(List(""" package Total enum Opt: Nope, Yep(get) @@ -668,8 +800,8 @@ main = one } test("Leibniz type equality example") { - evalTest( - List(""" + evalTest( + List(""" package A struct Leib(subst: forall f: * -> *. f[a] -> f[b]) @@ -698,11 +830,13 @@ def getValue(v: StringOrInt[a]) -> a: case IsInt(i, leib): coerce(i, leib) main = getValue(int) -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) - // If we leave out the coerce it fails - evalFail( - List(""" + // If we leave out the coerce it fails + evalFail(List(""" package A struct Leib(subst: forall f. f[a] -> f[b]) @@ -724,13 +858,12 @@ def getValue(v): case IsInt(i, _): i main = getValue(int) -""")){ case PackageError.TypeErrorIn(_, _) => () } +""")) { case PackageError.TypeErrorIn(_, _) => () } } test("overly generic methods fail compilation") { - evalFail( - List(""" + evalFail(List(""" package A # this shouldn't compile, a is too generic @@ -738,12 +871,11 @@ def plus(x: a, y): x.add(y) main = plus(1, 2) -""")){ case PackageError.TypeErrorIn(_, _) => () } +""")) { case PackageError.TypeErrorIn(_, _) => () } } test("unused let fails compilation") { - evalFail( - List(""" + evalFail(List(""" package A # this shouldn't compile, z is unused @@ -752,7 +884,7 @@ def plus(x, y): x.add(y) main = plus(1, 2) -""")){ case le@PackageError.UnusedLetError(_, _) => +""")) { case le @ PackageError.UnusedLetError(_, _) => val msg = le.message(Map.empty, Colorize.None) assert(!msg.contains("Name(")) assert(msg.contains("unused let binding: z\n Region(68,73)")) @@ -761,8 +893,8 @@ main = plus(1, 2) } test("structual recursion is allowed") { - evalTest( - List(""" + evalTest( + List(""" package A def len(lst, acc): @@ -771,10 +903,13 @@ def len(lst, acc): [_, *tail]: len(tail, acc.add(1)) main = len([1, 2, 3], 0) -"""), "A", VInt(3)) +"""), + "A", + VInt(3) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum PNat: One, Even(of: PNat), Odd(of: PNat) @@ -786,10 +921,12 @@ def toInt(pnat): Odd(of): toInt(of).times(2).add(1) main = toInt(Even(Even(One))) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) - evalFail( - List(""" + evalFail(List(""" package A enum Foo: Bar, Baz @@ -800,20 +937,23 @@ def bad(foo): baz: bad(baz) main = bad(Bar) -""")){ case PackageError.RecursionError(_, _) => () } +""")) { case PackageError.RecursionError(_, _) => () } - evalTest( - List(""" + evalTest( + List(""" package A big_list = range(3_000) main = big_list.foldLeft(0, add) -"""), "A", VInt((0 until 3000).sum)) +"""), + "A", + VInt((0 until 3000).sum) + ) - def sumFn(n: Int): Int = if (n <= 0) 0 else { sumFn(n-1) + n } - evalTest( - List(""" + def sumFn(n: Int): Int = if (n <= 0) 0 else { sumFn(n - 1) + n } + evalTest( + List(""" package A enum Nat: Zero, Succ(of: Nat) @@ -829,11 +969,14 @@ def sum(nat): Succ(n): sum(n).add(toInt(nat)) main = sum(Succ(Succ(Succ(Zero)))) -"""), "A", VInt(sumFn(3))) +"""), + "A", + VInt(sumFn(3)) + ) - // try with Succ first in the Nat - evalTest( - List(""" + // try with Succ first in the Nat + evalTest( + List(""" package A enum Nat: Zero, Succ(of: Nat) @@ -849,12 +992,15 @@ def sum(nat): Zero: 0 main = sum(Succ(Succ(Succ(Zero)))) -"""), "A", VInt(sumFn(3))) +"""), + "A", + VInt(sumFn(3)) + ) } test("we can mix literal and enum forms of List") { - evalTest( - List(""" + evalTest( + List(""" package A def len(lst, acc): @@ -863,9 +1009,12 @@ def len(lst, acc): [_, *tail]: len(tail, acc.add(1)) main = len([1, 2, 3], 0) -"""), "A", VInt(3)) - evalTest( - List(""" +"""), + "A", + VInt(3) + ) + evalTest( + List(""" package A def len(lst, acc): @@ -874,47 +1023,65 @@ def len(lst, acc): NonEmptyList(_, tail): len(tail, acc.add(1)) main = len([1, 2, 3], 0) -"""), "A", VInt(3)) +"""), + "A", + VInt(3) + ) } test("list comphension test") { - evalTest( - List(""" + evalTest( + List(""" package A main = [x for x in range(4)].foldLeft(0, add) -"""), "A", VInt(6)) - evalTest( - List(""" +"""), + "A", + VInt(6) + ) + evalTest( + List(""" package A main = [*[x] for x in range(4)].foldLeft(0, add) -"""), "A", VInt(6)) +"""), + "A", + VInt(6) + ) - evalTest( - List(""" + evalTest( + List(""" package A doub = [(x, x) for x in range(4)] main = [x.times(y) for (x, y) in doub].foldLeft(0, add) -"""), "A", VInt(1 + 4 + 9)) - evalTest( - List(""" +"""), + "A", + VInt(1 + 4 + 9) + ) + evalTest( + List(""" package A main = [x for x in range(4) if x.eq_Int(2)].foldLeft(0, add) -"""), "A", VInt(2)) +"""), + "A", + VInt(2) + ) - evalTest( - List(""" + evalTest( + List(""" package A main = [*[x, x] for x in range(4) if x.eq_Int(2)].foldLeft(0, add) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) - evalTest( - List(""" + evalTest( + List(""" package A def eq_List(lst1, lst2): @@ -936,12 +1103,15 @@ lst3 = [*[y, y] for (_, y) in [(x, x) for x in range(4)]] main = match (eq_List(lst1, lst2), eq_List(lst1, lst3)): case (True, True): 1 case _ : 0 -"""), "A", VInt(1)) +"""), + "A", + VInt(1) + ) } test("test fib using recursion") { - evalTest( - List(""" + evalTest( + List(""" package A enum Nat: Z, S(p: Nat) @@ -954,10 +1124,13 @@ def fib(n): # fib(5) = 1, 1, 2, 3, 5, 8 main = fib(S(S(S(S(S(Z)))))) -"""), "A", VInt(8)) +"""), + "A", + VInt(8) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum Nat[a]: Z, S(p: Nat[a]) @@ -970,10 +1143,13 @@ def fib(n): # fib(5) = 1, 1, 2, 3, 5, 8 main = fib(S(S(S(S(S(Z)))))) -"""), "A", VInt(8)) +"""), + "A", + VInt(8) + ) - evalTest( - List(""" + evalTest( + List(""" package A enum Nat: S(p: Nat), Z @@ -986,11 +1162,15 @@ def fib(n): # fib(5) = 1, 1, 2, 3, 5, 8 main = fib(S(S(S(S(S(Z)))))) -"""), "A", VInt(8)) +"""), + "A", + VInt(8) + ) } test("test matching the front of a list") { - evalTest(List(""" + evalTest( + List(""" package A def bad_len(list): @@ -999,9 +1179,13 @@ def bad_len(list): case [*init, _]: bad_len(init).add(1) main = bad_len([1, 2, 3, 5]) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) - evalTest(List(""" + evalTest( + List(""" package A def last(list): @@ -1010,10 +1194,14 @@ def last(list): case [*_, s]: s main = last([1, 2, 3, 5]) -"""), "A", VInt(5)) +"""), + "A", + VInt(5) + ) } test("test a named pattern that doesn't match") { - evalTest(List(""" + evalTest( + List(""" package A def bad_len(list): @@ -1031,10 +1219,14 @@ def bad_len(list): bad_len(init).add(1) main = bad_len([1, 2, 3, 5]) -"""), "A", VInt(4)) +"""), + "A", + VInt(4) + ) } test("uncurry2") { - evalTest(List(""" + evalTest( + List(""" package A struct TwoVar(one, two) @@ -1044,10 +1236,14 @@ constructed = uncurry2(x -> y -> TwoVar(x, y))(1, "two") main = match constructed: case TwoVar(1, "two"): "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) } test("uncurry3") { - evalTest(List(""" + evalTest( + List(""" package A struct ThreeVar(one, two, three) @@ -1057,11 +1253,15 @@ constructed = uncurry3(x -> y -> z -> ThreeVar(x, y, z))(1, "two", 3) main = match constructed: case ThreeVar(1, "two", 3): "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) } test("Dict methods") { - evalTest(List(""" + evalTest( + List(""" package A e = empty_Dict(string_Order) @@ -1069,9 +1269,13 @@ e = empty_Dict(string_Order) e1 = e.add_key("hello", "world") main = e1.get_key("hello") -"""), "A", VOption.some(Str("world"))) +"""), + "A", + VOption.some(Str("world")) + ) - evalTest(List(""" + evalTest( + List(""" package A e = empty_Dict(string_Order) @@ -1079,9 +1283,13 @@ e = empty_Dict(string_Order) e1 = e.clear_Dict().add_key("hello2", "world2") main = e1.get_key("hello") -"""), "A", VOption.none) +"""), + "A", + VOption.none + ) - evalTest(List(""" + evalTest( + List(""" package A e = empty_Dict(string_Order) @@ -1090,9 +1298,13 @@ e1 = e.add_key("hello", "world") e2 = e1.remove_key("hello") main = e2.get_key("hello") -"""), "A", VOption.none) +"""), + "A", + VOption.none + ) - evalTest(List(""" + evalTest( + List(""" package A e1 = empty_Dict(string_Order) @@ -1102,9 +1314,13 @@ lst = e2.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A e1 = {} @@ -1114,9 +1330,13 @@ lst = e2.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A e = { @@ -1128,9 +1348,13 @@ lst = e.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A pairs = [("hello", "world"), ("hello1", "world1")] @@ -1141,9 +1365,13 @@ lst = e.items() main = match lst: case [("hello", "world"), ("hello1", "world1")]: "good" case _: "bad" -"""), "A", Str("good")) +"""), + "A", + Str("good") + ) - evalTest(List(""" + evalTest( + List(""" package A pairs = [("hello", 42), ("hello1", 24)] @@ -1159,7 +1387,10 @@ lst = e.items() main = match lst: case [("hello", res)]: res case _: -1 -"""), "A", VInt(42)) +"""), + "A", + VInt(42) + ) evalTestJson( List(""" @@ -1168,7 +1399,10 @@ package Foo bar = {'a': '1', 's': 'foo' } main = bar -"""), "Foo", Json.JObject(List("a" -> Json.JString("1"), "s" -> Json.JString("foo")))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JString("1"), "s" -> Json.JString("foo"))) + ) evalTestJson( List(""" @@ -1178,7 +1412,10 @@ package Foo bar: Dict[String, Option[Int]] = {'a': None, 's': None } main = bar -"""), "Foo", Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNull))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNull)) + ) evalTestJson( List(""" @@ -1187,7 +1424,10 @@ package Foo bar = {'a': None, 's': Some(1) } main = bar -"""), "Foo", Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNumberStr("1")))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JNull, "s" -> Json.JNumberStr("1"))) + ) evalTestJson( List(""" @@ -1196,9 +1436,15 @@ package Foo bar = {'a': [], 's': [1] } main = bar -"""), "Foo", Json.JObject( - List("a" -> Json.JArray(Vector.empty), - "s" -> Json.JArray(Vector(Json.JNumberStr("1")))))) +"""), + "Foo", + Json.JObject( + List( + "a" -> Json.JArray(Vector.empty), + "s" -> Json.JArray(Vector(Json.JNumberStr("1"))) + ) + ) + ) evalTestJson( List(""" @@ -1207,32 +1453,36 @@ package Foo bar = {'a': True, 's': False } main = bar -"""), "Foo", Json.JObject( - List("a" -> Json.JBool(true), - "s" -> Json.JBool(false)))) +"""), + "Foo", + Json.JObject(List("a" -> Json.JBool(true), "s" -> Json.JBool(false))) + ) evalTestJson( List(""" package Foo main = (1, "1", ()) -"""), "Foo", Json.JArray( - Vector(Json.JNumberStr("1"), - Json.JString("1"), - Json.JNull))) +"""), + "Foo", + Json.JArray(Vector(Json.JNumberStr("1"), Json.JString("1"), Json.JNull)) + ) evalTestJson( List(""" package Foo main = [Some(Some(1)), Some(None), None] -"""), "Foo", - Json.JArray( - Vector( - Json.JArray(Vector(Json.JNumberStr("1"))), - Json.JArray(Vector(Json.JNull)), - Json.JArray(Vector.empty) - ))) +"""), + "Foo", + Json.JArray( + Vector( + Json.JArray(Vector(Json.JNumberStr("1"))), + Json.JArray(Vector(Json.JNull)), + Json.JArray(Vector.empty) + ) + ) + ) evalTestJson( List(""" @@ -1241,13 +1491,15 @@ package Foo enum FooBar: Foo(foo), Bar(bar) main = [Foo(1), Bar("1")] -"""), "Foo", - Json.JArray( - Vector( - Json.JObject( - List("foo" -> Json.JNumberStr("1"))), - Json.JObject( - List("bar" -> Json.JString("1")))))) +"""), + "Foo", + Json.JArray( + Vector( + Json.JObject(List("foo" -> Json.JNumberStr("1"))), + Json.JObject(List("bar" -> Json.JString("1"))) + ) + ) + ) } test("json handling of Nat special case") { @@ -1258,12 +1510,12 @@ package Foo enum Nat: Z, S(n: Nat) main = [Z, S(Z), S(S(Z))] -"""), "Foo", - Json.JArray( - Vector( - Json.JNumberStr("0"), - Json.JNumberStr("1"), - Json.JNumberStr("2")))) +"""), + "Foo", + Json.JArray( + Vector(Json.JNumberStr("0"), Json.JNumberStr("1"), Json.JNumberStr("2")) + ) + ) } test("json with backticks") { @@ -1276,29 +1528,37 @@ struct Foo(`struct`, `second key`, `enum`, `def`) `package` = 2 main = Foo(1, `package`, 3, 4) -"""), "Foo", - Json.JObject( - List( - ("struct" -> Json.JNumberStr("1")), - ("second key" -> Json.JNumberStr("2")), - ("enum" -> Json.JNumberStr("3")), - ("def" -> Json.JNumberStr("4"))) - )) +"""), + "Foo", + Json.JObject( + List( + ("struct" -> Json.JNumberStr("1")), + ("second key" -> Json.JNumberStr("2")), + ("enum" -> Json.JNumberStr("3")), + ("def" -> Json.JNumberStr("4")) + ) + ) + ) } test("test operators") { - evalTest(List(""" + evalTest( + List(""" package A operator + = add operator * = times main = 1 + 2 * 3 -"""), "A", VInt(7)) +"""), + "A", + VInt(7) + ) } test("patterns in lambdas") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A # you can't write \x: Int -> x.add(1) @@ -1308,17 +1568,25 @@ inc: Int -> Int = x -> x.add(1) inc2: Int -> Int = (x: Int) -> x.add(1) test = Assertion(inc(1).eq_Int(inc2(1)), "inc(1) == 2") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A def inc(x: Int): x.add(1) test = Assertion(inc(1).eq_Int(2), "inc(1) == 2") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Foo(v) @@ -1345,9 +1613,13 @@ test5 = Assertion(inc4(Pair(F(1), Foo(1))).eq_Int(2), "inc4(Pair(F(1), Foo(1))) test6 = Assertion(inc4(Pair(B(1), Foo(1))).eq_Int(2), "inc4(Pair(B(1), Foo(1))) == 2") suite = TestSuite("match tests", [test0, test1, test2, test3, test4, test5, test6]) -"""), "A", 7) +"""), + "A", + 7 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Foo(v) @@ -1374,96 +1646,111 @@ test5 = Assertion(inc4(Pair(F(1), Foo(1))).eq_Int(2), "inc4(Pair(F(1), Foo(1))) test6 = Assertion(inc4(Pair(B(1), Foo(1))).eq_Int(2), "inc4(Pair(B(1), Foo(1))) == 2") suite = TestSuite("match tests", [test0, test1, test2, test3, test4, test5, test6]) -"""), "A", 7) +"""), + "A", + 7 + ) } test("test some error messages") { evalFail( - List(""" + List( + """ package A a = 1 -""", """ +""", + """ package B from A import a -main = a""")) { case PackageError.UnknownImportName(_, _, _, _, _) => () } +main = a""" + ) + ) { case PackageError.UnknownImportName(_, _, _, _, _) => () } - evalFail( - List(""" + evalFail(List(""" package B from A import a main = a""")) { case PackageError.UnknownImportPackage(_, _) => () } - evalFail( - List(""" + evalFail(List(""" package B -main = a""")) { case te@PackageError.TypeErrorIn(_, _) => - val msg = te.message(Map.empty, Colorize.None) - assert(!msg.contains("Name(")) - assert(msg.contains("package B\nname \"a\" unknown")) - () - } +main = a""")) { case te @ PackageError.TypeErrorIn(_, _) => + val msg = te.message(Map.empty, Colorize.None) + assert(!msg.contains("Name(")) + assert(msg.contains("package B\nname \"a\" unknown")) + () + } - evalFail( - List(""" + evalFail(List(""" package B x = 1 main = match x: case Foo: 2 -""")) { case te@PackageError.SourceConverterErrorIn(_, _) => - val msg = te.message(Map.empty, Colorize.None) - assert(!msg.contains("Name(")) - assert(msg.contains("package B\nunknown constructor Foo")) - () - } +""")) { case te @ PackageError.SourceConverterErrorIn(_, _) => + val msg = te.message(Map.empty, Colorize.None) + assert(!msg.contains("Name(")) + assert(msg.contains("package B\nunknown constructor Foo")) + () + } - evalFail( - List(""" + evalFail(List(""" package B struct X main = match 1: case X1: 0 -""")) { case te@PackageError.SourceConverterErrorIn(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package B\nunknown constructor X1\nRegion(49,50)") +""")) { case te @ PackageError.SourceConverterErrorIn(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package B\nunknown constructor X1\nRegion(49,50)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A main = match [1, 2, 3]: case []: 0 case [*a, *b, _]: 2 -""")) { case te@PackageError.TotalityCheckError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nRegion(19,70)\nmultiple splices in pattern, only one per match allowed") +""")) { case te @ PackageError.TotalityCheckError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nRegion(19,70)\nmultiple splices in pattern, only one per match allowed" + ) () } - evalFail( - List(""" + evalFail(List(""" package A enum Foo: Bar(a), Baz(b) main = match Bar(a): case Baz(b): b -""")) { case te@PackageError.TotalityCheckError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nRegion(45,75)\nnon-total match, missing: Bar(_)") +""")) { case te @ PackageError.TotalityCheckError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nRegion(45,75)\nnon-total match, missing: Bar(_)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1471,13 +1758,17 @@ def fn(x): y: 0 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur but no recursive call to fn\nRegion(25,42)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nrecur but no recursive call to fn\nRegion(25,42)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1485,13 +1776,17 @@ def fn(x): y: 0 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,43)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,43)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1499,13 +1794,17 @@ def fn(x): y: 0 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,42)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nrecur not on an argument to the def of fn, args: (x)\nRegion(25,42)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1515,13 +1814,17 @@ def fn(x): z: 100 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nunexpected recur: may only appear unnested inside a def\nRegion(47,70)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nunexpected recur: may only appear unnested inside a def\nRegion(47,70)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x): @@ -1532,13 +1835,17 @@ def fn(x): z: 100 main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\nillegal shadowing on: fn. Recursive shadowing of def names disallowed\nRegion(25,81)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nillegal shadowing on: fn. Recursive shadowing of def names disallowed\nRegion(25,81)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x, y): @@ -1547,13 +1854,17 @@ def fn(x, y): case x: fn(x - 1, y + 1) main = fn -""")) { case te@PackageError.RecursionError(_, _) => - assert(te.message(Map.empty, Colorize.None) == "in file: , package A\ninvalid recursion on fn\nRegion(63,79)\n") +""")) { case te @ PackageError.RecursionError(_, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in file: , package A\ninvalid recursion on fn\nRegion(63,79)\n" + ) () } - evalFail( - List(""" + evalFail(List(""" package A def fn(x, y): @@ -1562,13 +1873,16 @@ def fn(x, y): case x: x main = fn(0, 1, 2) -""")) { case te@PackageError.TypeErrorIn(_, _) => - assert(te.message(Map.empty, Colorize.None).contains("does not match function with 3 arguments at:")) +""")) { case te @ PackageError.TypeErrorIn(_, _) => + assert( + te.message(Map.empty, Colorize.None) + .contains("does not match function with 3 arguments at:") + ) () } // we should have the region set inside - val code1571 = """ + val code1571 = """ package A def fn(x): @@ -1578,48 +1892,72 @@ def fn(x): main = fn([1, 2]) """ - evalFail(code1571 :: Nil) { case te@PackageError.TypeErrorIn(_, _) => + evalFail(code1571 :: Nil) { case te @ PackageError.TypeErrorIn(_, _) => // Make sure we point at the function directly assert(code1571.substring(67, 69) == "fn") - assert(te.message(Map.empty, Colorize.None) - .contains("the first type is a function with one argument and the second is a function with 2 arguments")) - assert(te.message(Map.empty, Colorize.None) - .contains("Region(67,69)")) + assert( + te.message(Map.empty, Colorize.None) + .contains( + "the first type is a function with one argument and the second is a function with 2 arguments" + ) + ) + assert( + te.message(Map.empty, Colorize.None) + .contains("Region(67,69)") + ) () } evalFail( - List(""" + List( + """ package A export foo foo = 3 -""", """ +""", + """ package B from A import fooz baz = fooz -""")) { case te@PackageError.UnknownImportName(_, _, _, _, _) => - assert(te.message(Map.empty, Colorize.None) == "in package: A does not have name fooz. Nearest: foo") +""" + ) + ) { case te @ PackageError.UnknownImportName(_, _, _, _, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in package: A does not have name fooz. Nearest: foo" + ) () } evalFail( - List(""" + List( + """ package A export foo foo = 3 bar = 3 -""", """ +""", + """ package B from A import bar baz = bar -""")) { case te@PackageError.UnknownImportName(_, _, _, _, _) => - assert(te.message(Map.empty, Colorize.None) == "in package: A has bar but it is not exported. Add to exports") +""" + ) + ) { case te @ PackageError.UnknownImportName(_, _, _, _, _) => + assert( + te.message( + Map.empty, + Colorize.None + ) == "in package: A has bar but it is not exported. Add to exports" + ) () } } @@ -1643,7 +1981,10 @@ tests = TestSuite("test triple", [ Assertion(a.eq_Int(3), "a == 3"), Assertion(bgood, b), Assertion(c.eq_Int(1), "c == 1") ]) -"""), "A", 3) +"""), + "A", + 3 + ) } test("regression from a map_List/list comprehension example from snoble") { @@ -1790,7 +2131,10 @@ tests = TestSuite("reordering", Assertion(equal_rows.equal_List(rs0.list_of_rows(), [[REBool(RecordValue(False)), REInt(RecordValue(1)), REString(RecordValue("a")), REInt(RecordValue(3))]]), "swap") ] ) -"""), "RecordSet/Library", 1) +"""), + "RecordSet/Library", + 1 + ) } test("record patterns") { @@ -1807,7 +2151,10 @@ tests = TestSuite("test record", [ Assertion(f2.eq_Int(1), "f2 == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1824,7 +2171,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1840,7 +2190,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1856,7 +2209,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1872,7 +2228,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1888,7 +2247,10 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) runBosatsuTest( List(""" @@ -1906,10 +2268,12 @@ tests = TestSuite("test record", [ Assertion(res.eq_Int(1), "res == 1"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1919,10 +2283,11 @@ get = Pair(first, ...) -> first # missing second first = 1 res = get(Pair { first }) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1933,10 +2298,11 @@ get = Pair(first, ...) -> first first = 1 second = 3 res = get(Pair { first, second, third }) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1944,10 +2310,11 @@ struct Pair(first, second) get = Pair { first } -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1956,10 +2323,11 @@ struct Pair(first, second) get = Pair(first) -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1968,10 +2336,11 @@ struct Pair(first, second) get = \Pair { first, sec: _ } -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1980,10 +2349,11 @@ struct Pair(first, second) get = Pair { first, sec: _, ... } -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -1992,10 +2362,11 @@ struct Pair(first, second) get = Pair(first, _, _) -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } - evalFail( - List(""" + evalFail(List(""" package A struct Pair(first, second) @@ -2004,11 +2375,14 @@ struct Pair(first, second) get = Pair(first, _, _, ...) -> first res = get(Pair(1, "two")) -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => s.message(Map.empty, Colorize.None); () } +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + s.message(Map.empty, Colorize.None); () + } } test("exercise total matching inside of a struct with a list") { - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct ListWrapper(items: List[a], b: Bool) @@ -2017,9 +2391,13 @@ w = ListWrapper([], True) ListWrapper([*_], r) = w tests = Assertion(r, "match with total list pattern") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct ListWrapper2(items: List[a], others: List[b], b: Bool) @@ -2028,9 +2406,13 @@ w = ListWrapper2([], [], True) ListWrapper2(_, _, r) = w tests = Assertion(r, "match with total list pattern") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct ListWrapper(items: List[(a, b)], b: Bool) @@ -2039,12 +2421,16 @@ w = ListWrapper([], True) ListWrapper(_, r) = w tests = Assertion(r, "match with total list pattern") -"""), "A", 1) +"""), + "A", + 1 + ) } test("test scoping bug (issue #311)") { - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(x, y) @@ -2054,9 +2440,13 @@ tests = TestSuite("test record", [ Assertion(x.eq_Int(42), "x == 42"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(x, y) @@ -2068,9 +2458,13 @@ tests = TestSuite("test record", [ Assertion(x.eq_Int(42), "x == 42"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(x, y) @@ -2091,12 +2485,16 @@ tests = TestSuite("test record", [ Assertion(y.eq_Int(43), "y == 43"), ]) -"""), "A", 1) +"""), + "A", + 1 + ) } test("test ordered shadowing issue #328") { - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A one = 1 @@ -2112,10 +2510,14 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) // test record syntax - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A struct Foo(one) @@ -2135,10 +2537,14 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) // test local shadowing of a duplicate - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A one = 1 @@ -2157,10 +2563,14 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) // test an example using a predef function, like add - runBosatsuTest(List("""package A + runBosatsuTest( + List("""package A # this should be add from predef two = add(1, 1) @@ -2176,51 +2586,67 @@ tests = TestSuite("test", [ Assertion(good, ""), ]) -"""), "A", 1) +"""), + "A", + 1 + ) } test("shadowing of external def isn't allowed") { - evalFail( - List(""" + evalFail(List(""" package A external def foo(x: String) -> List[String] def foo(x): x -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => - assert(s.message(Map.empty, Colorize.None) == "in file: , package A\nbind names foo shadow external def\nRegion(57,71)") +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + assert( + s.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nbind names foo shadow external def\nRegion(57,71)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A external def foo(x: String) -> List[String] foo = 1 -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => - assert(s.message(Map.empty, Colorize.None) == "in file: , package A\nbind names foo shadow external def\nRegion(57,65)") +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + assert( + s.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nbind names foo shadow external def\nRegion(57,65)" + ) () } - evalFail( - List(""" + evalFail(List(""" package A external def foo(x: String) -> List[String] external def foo(x: String) -> List[String] -""")) { case s@PackageError.SourceConverterErrorIn(_, _) => - assert(s.message(Map.empty, Colorize.None) == "in file: , package A\nexternal def: foo defined multiple times\nRegion(21,55)") +""")) { case s @ PackageError.SourceConverterErrorIn(_, _) => + assert( + s.message( + Map.empty, + Colorize.None + ) == "in file: , package A\nexternal def: foo defined multiple times\nRegion(21,55)" + ) () } } test("test meta escape bug") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Build[f] @@ -2234,69 +2660,92 @@ def useList(args: List[Build[File]]): check = useList([]) tests = Assertion(check, "none") -"""), "A", 1) +"""), + "A", + 1 + ) } test("type parameters must be supersets for structs and enums fails") { -evalFail( - List(""" + evalFail(List(""" package Err struct Foo[a](a) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nFoo found declared: [a], not a superset of [b]\nRegion(14,30)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nFoo found declared: [a], not a superset of [b]\nRegion(14,30)" + ) () } -evalFail( - List(""" + evalFail(List(""" package Err struct Foo[a](a: a, b: b) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nFoo found declared: [a], not a superset of [a, b]\nRegion(14,39)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nFoo found declared: [a], not a superset of [a, b]\nRegion(14,39)" + ) () } -evalFail( - List(""" + evalFail(List(""" package Err enum Enum[a]: Foo(a) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nEnum found declared: [a], not a superset of [b]\nRegion(14,34)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nEnum found declared: [a], not a superset of [b]\nRegion(14,34)" + ) () } -evalFail( - List(""" + evalFail(List(""" package Err enum Enum[a]: Foo(a: a), Bar(a: b) main = Foo(1, "2") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nEnum found declared: [a], not a superset of [a, b]\nRegion(14,48)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nEnum found declared: [a], not a superset of [a, b]\nRegion(14,48)" + ) () } } test("test duplicate import message") { - evalFail( - List(""" + evalFail(List(""" package Err from Bosatsu/Predef import foldLeft main = 1 -""")) { case sce@PackageError.DuplicatedImport(_) => - assert(sce.message(Map.empty, Colorize.None) == "duplicate import in package Bosatsu/Predef imports foldLeft as foldLeft") +""")) { case sce @ PackageError.DuplicatedImport(_) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "duplicate import in package Bosatsu/Predef imports foldLeft as foldLeft" + ) () } } @@ -2311,15 +2760,20 @@ main = 1 |main = 1 |""".stripMargin - evalFail(List(pack, pack)) { case sce@PackageError.DuplicatedPackageError(_) => - assert(sce.message(Map.empty, Colorize.None) == "package Err duplicated in 0, 1") - () + evalFail(List(pack, pack)) { + case sce @ PackageError.DuplicatedPackageError(_) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "package Err duplicated in 0, 1" + ) + () } } test("test bad list pattern message") { - evalFail( - List(""" + evalFail(List(""" package Err x = [1, 2, 3] @@ -2328,16 +2782,20 @@ main = match x: case [*_, *_]: "bad" case _: "still bad" -""")) { case sce@PackageError.TotalityCheckError(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nRegion(36,89)\nmultiple splices in pattern, only one per match allowed") +""")) { case sce @ PackageError.TotalityCheckError(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nRegion(36,89)\nmultiple splices in pattern, only one per match allowed" + ) () } } test("test bad string pattern message") { val dollar = '$' - evalFail( - List(s""" + evalFail(List(s""" package Err x = "foo bar" @@ -2346,16 +2804,19 @@ main = match x: case "$dollar{_}$dollar{_}": "bad" case _: "still bad" -""")) { case sce@PackageError.TotalityCheckError(_, _) => +""")) { case sce @ PackageError.TotalityCheckError(_, _) => val dollar = '$' - assert(sce.message(Map.empty, Colorize.None) == - s"in file: , package Err\nRegion(36,91)\ninvalid string pattern: '$dollar{_}$dollar{_}' (adjacent bindings aren't allowed)") + assert( + sce.message(Map.empty, Colorize.None) == + s"in file: , package Err\nRegion(36,91)\ninvalid string pattern: '$dollar{_}$dollar{_}' (adjacent bindings aren't allowed)" + ) () } } test("test parsing type annotations") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A x: Int = 1 @@ -2366,9 +2827,13 @@ y = ( ) tests = Assertion(y.eq_Int(x), "none") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A x: Int = 1 @@ -2379,11 +2844,15 @@ y = ( ) tests = Assertion(y.eq_Int(x), "none") -"""), "A", 1) +"""), + "A", + 1 + ) } test("improve coverage of typedexpr normalization") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A enum MyBool: T, F @@ -2392,9 +2861,13 @@ main = match T: case F: False tests = Assertion(main, "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A f = _ -> True @@ -2406,9 +2879,13 @@ tests = Assertion(fn((y = 1 # ignore y _ = y 2)), "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A def inc(x): @@ -2419,9 +2896,13 @@ def inc(x): z.add(y) tests = Assertion(inc(1).eq_Int(2), "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A w = 1 @@ -2436,9 +2917,13 @@ def inc(x): case x: x tests = Assertion(inc(1).eq_Int(2), "t1") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package QueueTest struct Queue[a](front: List[a], back: List[a]) @@ -2448,18 +2933,26 @@ def fold_Queue(Queue(f, b): Queue[a], binit: b, fold_fn: (b, a) -> b) -> b: b.reverse().foldLeft(front, fold_fn) test = Assertion(Queue([1], [2]).fold_Queue(0, add).eq_Int(3), "foldQueue") -"""), "QueueTest", 1) +"""), + "QueueTest", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A three = (x = 1 y -> x.add(y))(2) test = Assertion(three.eq_Int(3), "let inside apply") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A substitute = ( @@ -2469,22 +2962,28 @@ substitute = ( ) test = Assertion(substitute.eq_Int(42), "basis substitution") -"""), "A", 1) +"""), + "A", + 1 + ) } test("we can use .( ) to get |> like syntax for lambdas") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A three = 2.(x -> add(x, 1))() test = Assertion(three.eq_Int(3), "let inside apply") -"""), "A", 1) +"""), + "A", + 1 + ) } test("colliding type names cause errors") { - evalFail( - List(s""" + evalFail(List(s""" package Err struct Foo @@ -2492,15 +2991,19 @@ struct Foo struct Foo(x) main = Foo(1) -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\ntype name: Foo defined multiple times\nRegion(14,24)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\ntype name: Foo defined multiple times\nRegion(14,24)" + ) () } } test("colliding constructor names cause errors") { - evalFail( - List(s""" + evalFail(List(s""" package Err enum Bar: Foo @@ -2508,23 +3011,33 @@ enum Bar: Foo struct Foo(x) main = Foo(1) -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Err\nconstructor: Foo defined multiple times\nRegion(14,27)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Err\nconstructor: Foo defined multiple times\nRegion(14,27)" + ) () } } test("non binding top levels work") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A # this is basically a typecheck only _ = add(1, 2) test = Assertion(True, "") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A # this is basically a typecheck only @@ -2532,9 +3045,13 @@ x = (1, "1") (_, _) = x test = Assertion(True, "") -"""), "A", 1) +"""), + "A", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package A struct Foo(x, y) @@ -2543,11 +3060,15 @@ x = Foo(1, "1") Foo(_, _) = x test = Assertion(True, "") -"""), "A", 1) +"""), + "A", + 1 + ) } test("recursion check with _ pattern: issue 573") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package VarSet/Recursion enum Thing: @@ -2559,7 +3080,10 @@ def bar(y, _: String, x): Thing2(i, t): bar(i, "boom", t) test = Assertion(True, "") -"""), "VarSet/Recursion", 1) +"""), + "VarSet/Recursion", + 1 + ) } test("recursion check with shadowing") { @@ -2576,8 +3100,13 @@ def bar(y, _: String, x): Thing2(i, t): bar(i, "boom", t) test = Assertion(True, "") -""")) { case re@PackageError.RecursionError(_, _) => - assert(re.message(Map.empty, Colorize.None) == "in file: , package S\nrecur not on an argument to the def of bar, args: (y, _: String, x)\nRegion(107,165)\n") +""")) { case re @ PackageError.RecursionError(_, _) => + assert( + re.message( + Map.empty, + Colorize.None + ) == "in file: , package S\nrecur not on an argument to the def of bar, args: (y, _: String, x)\nRegion(107,165)\n" + ) () } } @@ -2590,8 +3119,13 @@ out = match (1,2): case (a, a): a test = Assertion(True, "") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(48,49)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(48,49)" + ) () } evalFail(List(""" @@ -2602,11 +3136,17 @@ out = match [(1,2), (1, 0)]: case _: 0 test = Assertion(True, "") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(68,69)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nrepeated bindings in pattern: a\nRegion(68,69)" + ) () } - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo out = match [(1,2), (1, 0)]: @@ -2614,11 +3154,15 @@ out = match [(1,2), (1, 0)]: case _: 0 test = Assertion(out.eq_Int(1), "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("test some complex list patterns, issue 574") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo out = match [(True, 2), (True, 0)]: @@ -2627,7 +3171,10 @@ out = match [(True, 2), (True, 0)]: case _: -1 test = Assertion(out.eq_Int(0), "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("unknown type constructor message is good. issue 653") { @@ -2638,8 +3185,13 @@ struct Bar(baz: Either[Int, String]) test = Assertion(True, "") -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nunknown type: Either\nRegion(14,50)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nunknown type: Either\nRegion(14,50)" + ) () } } @@ -2653,7 +3205,7 @@ export FooE() enum FooE: Foo1, Foo2 """ :: -""" + """ package Bar from Foo import Foo1, Foo2 @@ -2664,7 +3216,10 @@ m = match x: case Foo2: False test = Assertion(m, "x matches Foo1") -""" :: Nil, "Bar", 1) +""" :: Nil, + "Bar", + 1 + ) } test("its an error to export a value and not its type. issue 782") { @@ -2683,13 +3238,19 @@ from Foo import bar x = bar """ :: Nil) { case sce => - assert(sce.message(Map.empty, Colorize.None) == "in export bar of type Foo::Bar has an unexported (private) type.") + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in export bar of type Foo::Bar has an unexported (private) type." + ) () } } test("test def with type params") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo def foo[a](a: a) -> a: @@ -2699,7 +3260,10 @@ def foo[a](a: a) -> a: and_again(again(x)) test = Assertion(foo(True), "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) evalFail(List(""" package Foo @@ -2710,8 +3274,13 @@ def foo[a](a: a) -> a: def and_again[b](x: a): x and_again(again(x)) -""")) { case sce@PackageError.SourceConverterErrorIn(_, _) => - assert(sce.message(Map.empty, Colorize.None) == "in file: , package Foo\nand_again found declared types: [b], not a subset of [a]\nRegion(71,118)") +""")) { case sce @ PackageError.SourceConverterErrorIn(_, _) => + assert( + sce.message( + Map.empty, + Colorize.None + ) == "in file: , package Foo\nand_again found declared types: [b], not a subset of [a]\nRegion(71,118)" + ) () } } @@ -2729,12 +3298,12 @@ struct RecordGetter[shape, t]( def get[shape](sh: shape[RecordValue], RecordGetter(getter): RecordGetter[shape, t]) -> t: RecordValue(result) = sh.getter() result -""")) { case PackageError.TypeErrorIn(_, _) => () - } +""")) { case PackageError.TypeErrorIn(_, _) => () } } test("test quicklook example") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Foo def f(fn: forall a. List[a] -> List[a]) -> Int: @@ -2767,7 +3336,10 @@ pair = Pair1(single_id1, single_id2) comp = x -> f(g(x)) test = Assertion(True, "") -"""), "Foo", 1) +"""), + "Foo", + 1 + ) } test("ill-kinded structs point to the right region") { @@ -2776,12 +3348,14 @@ test = Assertion(True, "") package Foo struct Foo(a: f[a], b: f) -""")) { case kie@PackageError.KindInferenceError(_, _, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.KindInferenceError(_, _, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo shape error: expected kind(f) and * to match in the constructor Foo -Region(14,39)""") +Region(14,39)""" + ) () } @@ -2789,17 +3363,20 @@ Region(14,39)""") package Foo struct Foo[a: *](a: a[Int]) -""")) { case kie@PackageError.KindInferenceError(_, _, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo shape error: expected * -> ? but found * in the constructor Foo inside type a[Bosatsu/Predef::Int] +""")) { case kie @ PackageError.KindInferenceError(_, _, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo shape error: expected * -> ? but found * in the constructor Foo inside type a[Bosatsu/Predef::Int] -Region(14,41)""") +Region(14,41)""" + ) () } } test("example from issue #264") { - runBosatsuTest(""" + runBosatsuTest( + """ package SubsumeTest def lengths(l1: List[Int], l2: List[String], maybeFn: Option[forall tt. List[tt] -> Int]): @@ -2818,7 +3395,10 @@ x = match []: case [h, *_]: (h: forall a. a) test = Assertion(lengths([], [], None) matches 0, "test") - """ :: Nil, "SubsumeTest", 1) + """ :: Nil, + "SubsumeTest", + 1 + ) } test("ill kinded code examples") { @@ -2832,9 +3412,10 @@ struct Id(a) # this code could run if we ignored kinds def makeFoo(v: Int): Foo(Id(v)) -""")) { case kie@PackageError.TypeErrorIn(_, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.TypeErrorIn(_, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo kind error: the type: ?1 of kind: * -> * at: Region(183,188) @@ -2854,18 +3435,19 @@ struct Id(a) # this code could run if we ignored kinds def makeFoo(v: Int) -> Foo[Id, Int]: Foo(Id(v)) -""")) { case kie@PackageError.TypeErrorIn(_, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package Foo +""")) { case kie @ PackageError.TypeErrorIn(_, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package Foo kind error: the type: Foo::Foo[Foo::Id] is invalid because the left Foo::Foo has kind ((* -> *) -> *) -> (* -> *) -> * and the right Foo::Id has kind +* -> * but left cannot accept the kind of the right: Region(195,205)""" ) () } - + } test("print a decent message when arguments are omitted") { - evalFail(List(""" + evalFail(List(""" package QS def quick_sort0(cmp, left, right): @@ -2880,15 +3462,17 @@ def quick_sort0(cmp, left, right): # we accidentally omit bigger below bigs = quick_sort0(cmp, tail) [*smalls, *bigs] -""")) { case kie@PackageError.TypeErrorIn(_, _) => - assert(kie.message(Map.empty, Colorize.None) == - """in file: , package QS +""")) { case kie @ PackageError.TypeErrorIn(_, _) => + assert( + kie.message(Map.empty, Colorize.None) == + """in file: , package QS type error: expected type Bosatsu/Predef::Fn3[(?43, ?41) -> Bosatsu/Predef::Comparison] to be the same as type Bosatsu/Predef::Fn2 hint: the first type is a function with 3 arguments and the second is a function with 2 arguments. -Region(396,450)""") +Region(396,450)""" + ) () } - + } test("error early on a bad type in a recursive function") { @@ -2903,7 +3487,7 @@ def toInt(n: N, acc: Int) -> Int: case S(n): toInt(n, "foo") """ - evalFail(List(testCode)) { case kie@PackageError.TypeErrorIn(_, _) => + evalFail(List(testCode)) { case kie @ PackageError.TypeErrorIn(_, _) => val message = kie.message(Map.empty, Colorize.None) assert(message.contains("Region(122,127)")) val badRegion = testCode.substring(122, 127) @@ -2913,7 +3497,8 @@ def toInt(n: N, acc: Int) -> Int: } test("declaring a generic parameter works fine") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2925,9 +3510,13 @@ def head(nel: NEList[a]) -> a: case One(a) | Many(a, _): a test = Assertion(head(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2939,10 +3528,14 @@ def head[a](nel: NEList[a]) -> a: case One(a) | Many(a, _): a test = Assertion(head(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) // With recursion - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2955,10 +3548,14 @@ def last(nel: NEList[a]) -> a: case Many(_, tail): last(tail) test = Assertion(last(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) // With recursion - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Generic enum NEList[a: +*]: @@ -2971,7 +3568,10 @@ def last[a](nel: NEList[a]) -> a: case Many(_, tail): last(tail) test = Assertion(last(One(True)), "") -"""), "Generic", 1) +"""), + "Generic", + 1 + ) } test("support polymorphic recursion") { @@ -2990,8 +3590,10 @@ def poly_rec(count: Nat, a: a) -> a: b test = Assertion(True, "") -""") - , "PolyRec", 1) +"""), + "PolyRec", + 1 + ) runBosatsuTest( List(""" @@ -3013,8 +3615,10 @@ def call(a): poly_rec(NZero, a) test = Assertion(True, "") -""") - , "PolyRec", 1) +"""), + "PolyRec", + 1 + ) } test("recursion on continuations") { @@ -3038,7 +3642,10 @@ def loop(box: Cont) -> Int: v = loop(b) main = v -"""), "A", VInt(2)) +"""), + "A", + VInt(2) + ) // Generic version evalTest( @@ -3059,10 +3666,13 @@ def loop[a](box: Cont[a]) -> a: loopgen: forall a. Cont[a] -> a = loop b: Cont[Int] = Item(1).map(x -> x.add(1)) main: Int = loop(b) -"""), "A", VInt(2)) +"""), + "A", + VInt(2) + ) - // this example also exercises polymorphic recursion - evalTest( + // this example also exercises polymorphic recursion + evalTest( List(""" package A enum Box[a: +*]: @@ -3081,6 +3691,9 @@ def loop[a](box: Box[a]) -> a: v = loop(b) main = v -"""), "A", VInt(1)) +"""), + "A", + VInt(1) + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala b/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala index 5db08f8e6..f2209da8a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/FreeVarTest.scala @@ -1,14 +1,17 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class FreeVarTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 1000) - //PropertyCheckConfiguration(minSuccessful = 300) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 300) + // PropertyCheckConfiguration(minSuccessful = 5) def assertFreeVars(stmt: String, vars: List[String]) = Statement.parser.parseAll(stmt) match { @@ -25,14 +28,18 @@ class FreeVarTest extends AnyFunSuite { assertFreeVars("""y = 1""", Nil) assertFreeVars("""external foo: Int""", Nil) assertFreeVars("""def foo(x): y""", List("y")) - assertFreeVars("""def foo(x): + assertFreeVars( + """def foo(x): y = x - y""", Nil) + y""", + Nil + ) } test("freeVars is a subset of allNames") { forAll(Generators.genStatement(3)) { stmt => - Statement.valuesOf(stmt :: Nil) + Statement + .valuesOf(stmt :: Nil) .foreach { v => assert(v.freeVars.subsetOf(v.allNames)) } diff --git a/core/src/test/scala/org/bykn/bosatsu/Gen.scala b/core/src/test/scala/org/bykn/bosatsu/Gen.scala index ab203f898..c05cec543 100644 --- a/core/src/test/scala/org/bykn/bosatsu/Gen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/Gen.scala @@ -43,7 +43,10 @@ object Generators { for { e <- Gen.lzy(typeRefGen) cnt <- Gen.choose(1, 3) - args <- Gen.listOfN(cnt, Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind))) + args <- Gen.listOfN( + cnt, + Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind)) + ) nel = NonEmptyList.fromListUnsafe(args) } yield TypeRef.TypeForAll(nel, e) @@ -55,7 +58,10 @@ object Generators { multiGen = Gen.oneOf(Operators.multiToks) ms <- Gen.listOfN(c, multiGen) asStr = ms.mkString - res <- (if ((asStr != "<-") && (asStr != "->")) Gen.const(Identifier.Operator(asStr)) else multi) + res <- + (if ((asStr != "<-") && (asStr != "->")) + Gen.const(Identifier.Operator(asStr)) + else multi) } yield res Gen.frequency((4, sing), (1, multi)) @@ -63,11 +69,20 @@ object Generators { val bindIdentGen: Gen[Identifier.Bindable] = Gen.frequency( - (10, lowerIdent.filter { n => !Declaration.keywords(n) }.map { n => Identifier.Name(n) }), + ( + 10, + lowerIdent.filter { n => !Declaration.keywords(n) }.map { n => + Identifier.Name(n) + } + ), (1, opGen), - (1, Arbitrary.arbitrary[String].map { s => - Identifier.Backticked(s) - })) + ( + 1, + Arbitrary.arbitrary[String].map { s => + Identifier.Backticked(s) + } + ) + ) lazy val typeRefGen: Gen[TypeRef] = { import TypeRef._ @@ -93,10 +108,16 @@ object Generators { Gen.frequency( (5, tvar), (5, tname), - (1, Gen.zip(Gen.lzy(smallNonEmptyList(typeRefGen, 4)), Gen.lzy(typeRefGen)).map { case (a, b) => TypeArrow(a, b) }), + ( + 1, + Gen + .zip(Gen.lzy(smallNonEmptyList(typeRefGen, 4)), Gen.lzy(typeRefGen)) + .map { case (a, b) => TypeArrow(a, b) } + ), (1, tLambda), (1, tTup), - (1, tApply)) + (1, tApply) + ) } implicit val shrinkTypeRef: Shrink[TypeRef] = @@ -124,7 +145,9 @@ object Generators { }) def commentGen[T](dec: Gen[T]): Gen[CommentStatement[T]] = { - def cleanNewLine(s: String): String = s.map { c => if (c == '\n') ' ' else c } + def cleanNewLine(s: String): String = s.map { c => + if (c == '\n') ' ' else c + } for { cs <- nonEmpty(Arbitrary.arbitrary[String]) t <- dec @@ -139,7 +162,7 @@ object Generators { def argToPat(arg: (Identifier.Bindable, Option[TypeRef])): Pattern.Parsed = arg match { - case (bn, None) => Pattern.Var(bn) + case (bn, None) => Pattern.Var(bn) case (bn, Some(t)) => Pattern.Annotation(Pattern.Var(bn), t) } @@ -150,34 +173,63 @@ object Generators { tpes <- smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKind))) retType <- Gen.option(typeRefGen) body <- dec - } yield DefStatement(name, NonEmptyList.fromList(tpes), args.map(_.map(argToPat)), retType, body) + } yield DefStatement( + name, + NonEmptyList.fromList(tpes), + args.map(_.map(argToPat)), + retType, + body + ) - def genSpliceOrItem[A](spliceGen: Gen[A], itemGen: Gen[A]): Gen[ListLang.SpliceOrItem[A]] = - Gen.oneOf(spliceGen.map(ListLang.SpliceOrItem.Splice(_)), - itemGen.map(ListLang.SpliceOrItem.Item(_))) + def genSpliceOrItem[A]( + spliceGen: Gen[A], + itemGen: Gen[A] + ): Gen[ListLang.SpliceOrItem[A]] = + Gen.oneOf( + spliceGen.map(ListLang.SpliceOrItem.Splice(_)), + itemGen.map(ListLang.SpliceOrItem.Item(_)) + ) - def genListLangCons[A](spliceGen: Gen[A], itemGen: Gen[A]): Gen[ListLang.Cons[ListLang.SpliceOrItem, A]] = { - Gen.choose(0, 5) + def genListLangCons[A]( + spliceGen: Gen[A], + itemGen: Gen[A] + ): Gen[ListLang.Cons[ListLang.SpliceOrItem, A]] = { + Gen + .choose(0, 5) .flatMap(Gen.listOfN(_, genSpliceOrItem(spliceGen, itemGen))) .map(ListLang.Cons(_)) } - def genListLangDictCons[A](itemGen: Gen[A]): Gen[ListLang.Cons[ListLang.KVPair, A]] = { - Gen.choose(0, 5) - .flatMap(Gen.listOfN(_, - Gen.zip(itemGen, itemGen).map { case (k, v) => ListLang.KVPair(k, v) })) + def genListLangDictCons[A]( + itemGen: Gen[A] + ): Gen[ListLang.Cons[ListLang.KVPair, A]] = { + Gen + .choose(0, 5) + .flatMap( + Gen.listOfN( + _, + Gen.zip(itemGen, itemGen).map { case (k, v) => ListLang.KVPair(k, v) } + ) + ) .map(ListLang.Cons(_)) } def genStringDecl(dec0: Gen[NonBinding]): Gen[Declaration.StringDecl] = { val item = Gen.oneOf( - Arbitrary.arbitrary[String].filter(_.length > 1).map { s => Right((emptyRegion, s)) }, - dec0.map(Left(_))) + Arbitrary.arbitrary[String].filter(_.length > 1).map { s => + Right((emptyRegion, s)) + }, + dec0.map(Left(_)) + ) - def removeAdj[A](nea: NonEmptyList[A])(fn: (A, A) => Boolean): NonEmptyList[A] = + def removeAdj[A]( + nea: NonEmptyList[A] + )(fn: (A, A) => Boolean): NonEmptyList[A] = nea match { - case NonEmptyList(a1, a2 :: tail) if fn(a1, a2) => removeAdj(NonEmptyList(a2, tail))(fn) - case NonEmptyList(a1, a2 :: tail) => NonEmptyList(a1, removeAdj(NonEmptyList(a2, tail))(fn).toList) + case NonEmptyList(a1, a2 :: tail) if fn(a1, a2) => + removeAdj(NonEmptyList(a2, tail))(fn) + case NonEmptyList(a1, a2 :: tail) => + NonEmptyList(a1, removeAdj(NonEmptyList(a2, tail))(fn).toList) case ne1 => ne1 } @@ -198,8 +250,8 @@ object Generators { def listGen(dec0: Gen[NonBinding]): Gen[Declaration.ListDecl] = { lazy val filterFn: NonBinding => Boolean = { - case Declaration.IfElse(_, _) => false - case Declaration.Match(_, _, _) => false + case Declaration.IfElse(_, _) => false + case Declaration.Match(_, _, _) => false case Declaration.Lambda(_, body: NonBinding) => filterFn(body) case Declaration.Apply(f, args, _) => filterFn(f) && args.forall(filterFn) @@ -211,10 +263,11 @@ object Generators { // TODO we can't parse if since we get confused about it being a ternary expression val pat = genPattern(1, useUnion = true) - val comp = Gen.zip(genSpliceOrItem(dec, dec), pat, dec, Gen.option(dec)) + val comp = Gen + .zip(genSpliceOrItem(dec, dec), pat, dec, Gen.option(dec)) .map { case (a, b, c0, _) => val c = c0 match { - case tern@Declaration.Ternary(_, _, _) => + case tern @ Declaration.Ternary(_, _, _) => Declaration.Parens(tern)(emptyRegion) case not => not } @@ -226,10 +279,10 @@ object Generators { def dictGen(dec0: Gen[NonBinding]): Gen[Declaration.DictDecl] = { lazy val filterFn: NonBinding => Boolean = { - case Declaration.Annotation(_, _) => false - case Declaration.IfElse(_, _) => false - case Declaration.Match(_, _, _) => false - case Declaration.ApplyOp(_, _, _) => false + case Declaration.Annotation(_, _) => false + case Declaration.IfElse(_, _) => false + case Declaration.Match(_, _, _) => false + case Declaration.ApplyOp(_, _, _) => false case Declaration.Lambda(_, body: NonBinding) => filterFn(body) case Declaration.Apply(f, args, _) => filterFn(f) && args.forall(filterFn) @@ -241,10 +294,11 @@ object Generators { // TODO we can't parse if since we get confused about it being a ternary expression val pat = genPattern(1, useUnion = true) - val comp = Gen.zip(dec, dec, pat, dec, Gen.option(dec)) + val comp = Gen + .zip(dec, dec, pat, dec, Gen.option(dec)) .map { case (k, v, b, c0, _) => val c = c0 match { - case tern@Declaration.Ternary(_, _, _) => + case tern @ Declaration.Ternary(_, _, _) => Declaration.Parens(tern)(emptyRegion) case not => not } @@ -258,7 +312,9 @@ object Generators { decl.map { case n: Declaration.NonBinding => n match { - case v@(Declaration.Var(_) | Declaration.Parens(_) | Declaration.Apply(_, _, _)) => v + case v @ (Declaration.Var(_) | Declaration.Parens(_) | + Declaration.Apply(_, _, _)) => + v case notVar => Declaration.Parens(notVar)(emptyRegion) } case notVar => Declaration.Parens(notVar)(emptyRegion) @@ -270,18 +326,26 @@ object Generators { def isVar(d: Declaration): Boolean = d match { case Declaration.Var(_) => true - case _ => false + case _ => false } - def applyGen(fnGen: Gen[NonBinding], arg: Gen[NonBinding], dotApplyGen: Gen[Boolean]): Gen[Declaration.Apply] = { + def applyGen( + fnGen: Gen[NonBinding], + arg: Gen[NonBinding], + dotApplyGen: Gen[Boolean] + ): Gen[Declaration.Apply] = { import Declaration._ - Gen.lzy(for { - fn <- fnGen - dotApply <- dotApplyGen - useDot = dotApply && isVar(fn) // f.bar needs the fn to be a var - argsGen = if (useDot) arg.map(NonEmptyList.one(_)) else smallNonEmptyList(arg, 8) - args <- argsGen - } yield Apply(fn, args, ApplyKind.Parens)(emptyRegion)) // TODO this should pass if we use `foo.bar(a, b)` syntax + Gen.lzy( + for { + fn <- fnGen + dotApply <- dotApplyGen + useDot = dotApply && isVar(fn) // f.bar needs the fn to be a var + argsGen = + if (useDot) arg.map(NonEmptyList.one(_)) + else smallNonEmptyList(arg, 8) + args <- argsGen + } yield Apply(fn, args, ApplyKind.Parens)(emptyRegion) + ) // TODO this should pass if we use `foo.bar(a, b)` syntax } def applyOpGen(arg: Gen[NonBinding]): Gen[Declaration.ApplyOp] = @@ -299,20 +363,31 @@ object Generators { ApplyOp(protect(l), op, protect(r)) } - def bindGen[A, T](patGen: Gen[A], dec: Gen[NonBinding], tgen: Gen[T]): Gen[BindingStatement[A, NonBinding, T]] = - Gen.zip(patGen, dec, tgen) + def bindGen[A, T]( + patGen: Gen[A], + dec: Gen[NonBinding], + tgen: Gen[T] + ): Gen[BindingStatement[A, NonBinding, T]] = + Gen + .zip(patGen, dec, tgen) .map { case (b, value, in) => BindingStatement(b, value, in) } - def leftApplyGen(patGen: Gen[Pattern.Parsed], dec: Gen[NonBinding], bodyGen: Gen[Declaration]): Gen[Declaration.LeftApply] = - Gen.zip(patGen, dec, padding(bodyGen)) + def leftApplyGen( + patGen: Gen[Pattern.Parsed], + dec: Gen[NonBinding], + bodyGen: Gen[Declaration] + ): Gen[Declaration.LeftApply] = + Gen + .zip(patGen, dec, padding(bodyGen)) .map { case (p, value, in) => Declaration.LeftApply(p, emptyRegion, value, in) } def padding[T](tgen: Gen[T], min: Int = 0): Gen[Padding[T]] = - Gen.zip(Gen.choose(min, 10), tgen) + Gen + .zip(Gen.choose(min, 10), tgen) .map { case (e, t) => Padding(e, t) } def indented[T](tgen: Gen[T]): Gen[Indented[T]] = @@ -325,34 +400,37 @@ object Generators { for { args <- nonEmpty(bindIdentGen) body <- bodyGen - } yield Declaration.Lambda(args.map(Pattern.Var(_)), body)(emptyRegion) + } yield Declaration.Lambda(args.map(Pattern.Var(_)), body)(emptyRegion) def optIndent[A](genA: Gen[A]): Gen[OptIndent[A]] = { val indentation = Gen.choose(1, 10) indentation.flatMap { i => - - // TODO support parsing if foo: bar - //Gen.oneOf( - padding(genA.map(Indented(i, _)), min = 1).map(OptIndent.notSame(_)) - //, - //bodyGen.map(Left(_): OptIndent[Declaration])) + // TODO support parsing if foo: bar + // Gen.oneOf( + padding(genA.map(Indented(i, _)), min = 1).map(OptIndent.notSame(_)) + // , + // bodyGen.map(Left(_): OptIndent[Declaration])) } } - def ifElseGen(argGen0: Gen[NonBinding], bodyGen: Gen[Declaration]): Gen[Declaration.IfElse] = { + def ifElseGen( + argGen0: Gen[NonBinding], + bodyGen: Gen[Declaration] + ): Gen[Declaration.IfElse] = { import Declaration._ // args can't have raw annotations: val argGen = argGen0.map { - case ann@Annotation(_, _) => Parens(ann)(emptyRegion) - case notAnn => notAnn + case ann @ Annotation(_, _) => Parens(ann)(emptyRegion) + case notAnn => notAnn } val padBody = optIndent(bodyGen) val genIf: Gen[(NonBinding, OptIndent[Declaration])] = Gen.zip(argGen, padBody) - Gen.zip(nonEmptyN(genIf, 2), padBody) + Gen + .zip(nonEmptyN(genIf, 2), padBody) .map { case (ifs, elsec) => IfElse(ifs, elsec)(emptyRegion) } } @@ -360,14 +438,15 @@ object Generators { import Declaration._ val argGen = argGen0.map { - case lam@Lambda(_, _) => Parens(lam)(emptyRegion) - case ife@IfElse(_, _) => Parens(ife)(emptyRegion) - case tern@Ternary(_, _, _) => Parens(tern)(emptyRegion) - case matches@Matches(_, _) => Parens(matches)(emptyRegion) - case m@Match(_, _, _) => Parens(m)(emptyRegion) - case not => not + case lam @ Lambda(_, _) => Parens(lam)(emptyRegion) + case ife @ IfElse(_, _) => Parens(ife)(emptyRegion) + case tern @ Ternary(_, _, _) => Parens(tern)(emptyRegion) + case matches @ Matches(_, _) => Parens(matches)(emptyRegion) + case m @ Match(_, _, _) => Parens(m)(emptyRegion) + case not => not } - Gen.zip(argGen, argGen, argGen) + Gen + .zip(argGen, argGen, argGen) .map { case (t, c, f) => Ternary(t, c, f) } } @@ -379,25 +458,33 @@ object Generators { def toArg(p: Pattern.Parsed): Gen[Pattern.StructKind.Style.FieldKind] = p match { case Pattern.Var(b: Identifier.Bindable) => - Gen.oneOf(Gen.const(Pattern.StructKind.Style.FieldKind.Implicit(b)), - Gen.oneOf(bindIdentGen, Gen.const(b)) + Gen.oneOf( + Gen.const(Pattern.StructKind.Style.FieldKind.Implicit(b)), + Gen + .oneOf(bindIdentGen, Gen.const(b)) .map(Pattern.StructKind.Style.FieldKind.Explicit(_)) - ) + ) case Pattern.Annotation(p, _) => toArg(p) - case _ => + case _ => // if we don't have a var, we can't omit the key bindIdentGen.map(Pattern.StructKind.Style.FieldKind.Explicit(_)) } - lazy val args = tail.foldLeft(toArg(h) - .map(NonEmptyList.one)) { case (args, a) => - Gen.zip(args, toArg(a)).map { case (args, a) => NonEmptyList(a, args.toList) } + lazy val args = tail + .foldLeft( + toArg(h) + .map(NonEmptyList.one) + ) { case (args, a) => + Gen.zip(args, toArg(a)).map { case (args, a) => + NonEmptyList(a, args.toList) + } } .map(_.reverse) Gen.oneOf( Gen.const(Pattern.StructKind.Style.TupleLike), - Gen.lzy(args.map(Pattern.StructKind.Style.RecordLike(_)))) + Gen.lzy(args.map(Pattern.StructKind.Style.RecordLike(_))) + ) } def genStructKind(args: List[Pattern.Parsed]): Gen[Pattern.StructKind] = @@ -408,7 +495,8 @@ object Generators { }, Gen.zip(consIdentGen, genStyle(args)).map { case (n, s) => Pattern.StructKind.NamedPartial(n, s) - }) + } + ) def genPattern(depth: Int, useUnion: Boolean = true): Gen[Pattern.Parsed] = genPatternGen( @@ -416,18 +504,29 @@ object Generators { typeRefGen, depth, useUnion, - useAnnotation = false) + useAnnotation = false + ) - def genPatternGen[N, T](genName: List[Pattern[N, T]] => Gen[N], genT: Gen[T], depth: Int, useUnion: Boolean, useAnnotation: Boolean): Gen[Pattern[N, T]] = { - val recurse = Gen.lzy(genPatternGen(genName, genT, depth - 1, useUnion, useAnnotation)) + def genPatternGen[N, T]( + genName: List[Pattern[N, T]] => Gen[N], + genT: Gen[T], + depth: Int, + useUnion: Boolean, + useAnnotation: Boolean + ): Gen[Pattern[N, T]] = { + val recurse = + Gen.lzy(genPatternGen(genName, genT, depth - 1, useUnion, useAnnotation)) val genVar = bindIdentGen.map(Pattern.Var(_)) val genWild = Gen.const(Pattern.WildCard) val genLitPat = genLit.map(Pattern.Literal(_)) if (depth <= 0) Gen.oneOf(genVar, genWild, genLitPat) else { - val genNamed = Gen.zip(bindIdentGen, recurse).map { case (n, p) => Pattern.Named(n, p) } - val genTyped = Gen.zip(recurse, genT) + val genNamed = Gen.zip(bindIdentGen, recurse).map { case (n, p) => + Pattern.Named(n, p) + } + val genTyped = Gen + .zip(recurse, genT) .map { case (p, t) => Pattern.Annotation(p, t) } lazy val genStrPat: Gen[Pattern.StrPat] = { @@ -437,20 +536,26 @@ object Generators { Gen.oneOf( lowerIdent.map(Pattern.StrPart.LitStr(_)), bindIdentGen.map(Pattern.StrPart.NamedStr(_)), - Gen.const(Pattern.StrPart.WildStr)) + Gen.const(Pattern.StrPart.WildStr) + ) def isWild(p: Pattern.StrPart): Boolean = p match { case Pattern.StrPart.LitStr(_) => false - case _ => true + case _ => true } - def makeValid(nel: NonEmptyList[Pattern.StrPart]): NonEmptyList[Pattern.StrPart] = + def makeValid( + nel: NonEmptyList[Pattern.StrPart] + ): NonEmptyList[Pattern.StrPart] = nel match { case NonEmptyList(_, Nil) => nel case NonEmptyList(h1, h2 :: t) if isWild(h1) && isWild(h2) => makeValid(NonEmptyList(h2, t)) - case NonEmptyList(Pattern.StrPart.LitStr(h1), Pattern.StrPart.LitStr(h2) :: t) => + case NonEmptyList( + Pattern.StrPart.LitStr(h1), + Pattern.StrPart.LitStr(h2) :: t + ) => makeValid(NonEmptyList(Pattern.StrPart.LitStr(h1 + h2), t)) case NonEmptyList(h1, h2 :: t) => NonEmptyList(h1, makeValid(NonEmptyList(h2, t)).toList) @@ -464,7 +569,7 @@ object Generators { } yield notStr } - val genStruct = for { + val genStruct = for { cnt <- Gen.choose(0, 6) args <- Gen.listOfN(cnt, recurse) nm <- genName(args) @@ -473,38 +578,50 @@ object Generators { def makeOneSplice(ps: List[Pattern.ListPart[Pattern[N, T]]]) = { val sz = ps.size if (sz == 0) Gen.const(ps) - else Gen.choose(0, sz - 1).flatMap { idx => - val splice = Gen.oneOf( - Gen.const(Pattern.ListPart.WildList), - bindIdentGen.map { v => Pattern.ListPart.NamedList(v) }) - - splice.map { v => ps.updated(idx, v) } - } + else + Gen.choose(0, sz - 1).flatMap { idx => + val splice = Gen.oneOf( + Gen.const(Pattern.ListPart.WildList), + bindIdentGen.map { v => Pattern.ListPart.NamedList(v) } + ) + + splice.map { v => ps.updated(idx, v) } + } } val genListItem: Gen[Pattern.ListPart[Pattern[N, T]]] = recurse.map(Pattern.ListPart.Item(_)) - val genList = Gen.choose(0, 5) + val genList = Gen + .choose(0, 5) .flatMap(Gen.listOfN(_, genListItem)) .flatMap { ls => - Gen.oneOf(true, false) + Gen + .oneOf(true, false) .flatMap { - case true => Gen.const(ls) + case true => Gen.const(ls) case false => makeOneSplice(ls) } } .map(Pattern.ListPat(_)) - val genUnion = Gen.choose(0, 2) + val genUnion = Gen + .choose(0, 2) .flatMap { sz => Gen.zip(recurse, recurse, Gen.listOfN(sz, recurse)) } - .map { - case (h0, h1, tail) => - Pattern.union(h0, h1 :: tail) + .map { case (h0, h1, tail) => + Pattern.union(h0, h1 :: tail) } val tailGens: List[Gen[Pattern[N, T]]] = - List(genVar, genWild, genNamed, genStrPat, genLitPat, genStruct, genList) + List( + genVar, + genWild, + genNamed, + genStrPat, + genLitPat, + genStruct, + genList + ) val withU = if (useUnion) genUnion :: tailGens else tailGens val withT = (if (useAnnotation) genTyped :: withU else withU).toArray @@ -513,20 +630,33 @@ object Generators { } } - def genCompiledPattern(depth: Int, useUnion: Boolean = true, useAnnotation: Boolean = true): Gen[Pattern[(PackageName, Identifier.Constructor), rankn.Type]] = + def genCompiledPattern( + depth: Int, + useUnion: Boolean = true, + useAnnotation: Boolean = true + ): Gen[Pattern[(PackageName, Identifier.Constructor), rankn.Type]] = genPatternGen( - { (_: List[Pattern[(PackageName, Identifier.Constructor), rankn.Type]]) => Gen.zip(packageNameGen, consIdentGen) }, - NTypeGen.genDepth03, depth, useUnion = useUnion, useAnnotation = useAnnotation) + { (_: List[Pattern[(PackageName, Identifier.Constructor), rankn.Type]]) => + Gen.zip(packageNameGen, consIdentGen) + }, + NTypeGen.genDepth03, + depth, + useUnion = useUnion, + useAnnotation = useAnnotation + ) - def matchGen(argGen0: Gen[NonBinding], bodyGen: Gen[Declaration]): Gen[Declaration.Match] = { + def matchGen( + argGen0: Gen[NonBinding], + bodyGen: Gen[Declaration] + ): Gen[Declaration.Match] = { import Declaration._ val padBody = optIndent(bodyGen) // args can't have raw annotations: val argGen = argGen0.map { - case ann@Annotation(_, _) => Parens(ann)(emptyRegion) - case notAnn => notAnn + case ann @ Annotation(_, _) => Parens(ann)(emptyRegion) + case notAnn => notAnn } val genCase: Gen[(Pattern.Parsed, OptIndent[Declaration])] = @@ -534,7 +664,10 @@ object Generators { for { cnt <- Gen.choose(1, 2) - kind <- Gen.frequency((10, Gen.const(RecursionKind.NonRecursive)), (1, Gen.const(RecursionKind.Recursive))) + kind <- Gen.frequency( + (10, Gen.const(RecursionKind.NonRecursive)), + (1, Gen.const(RecursionKind.Recursive)) + ) expr <- argGen cases <- optIndent(nonEmptyN(genCase, cnt)) } yield Match(kind, expr, cases)(emptyRegion) @@ -546,7 +679,9 @@ object Generators { val fixa = a match { // matches binds tighter than all these - case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | Match(_, _, _) | Ternary(_, _, _) => Parens(a)(emptyRegion) + case Lambda(_, _) | IfElse(_, _) | ApplyOp(_, _, _) | Match(_, _, _) | + Ternary(_, _, _) => + Parens(a)(emptyRegion) case _ => a } Matches(fixa, p)(emptyRegion) @@ -554,12 +689,13 @@ object Generators { val genLit: Gen[Lit] = { val str = for { - //q <- Gen.oneOf('\'', '"') - //str <- Arbitrary.arbitrary[String] + // q <- Gen.oneOf('\'', '"') + // str <- Arbitrary.arbitrary[String] str <- lowerIdent // TODO } yield Lit.Str(str) - val bi = Arbitrary.arbitrary[BigInt].map { bi => Lit.Integer(bi.bigInteger) } + val bi = + Arbitrary.arbitrary[BigInt].map { bi => Lit.Integer(bi.bigInteger) } Gen.oneOf(str, bi) } @@ -576,19 +712,20 @@ object Generators { Gen.frequency( (1, consDeclGen), (2, varGen), - (1, genLit.map(Declaration.Literal(_)(emptyRegion)))) + (1, genLit.map(Declaration.Literal(_)(emptyRegion))) + ) def annGen(g: Gen[NonBinding]): Gen[Declaration.Annotation] = { import Declaration._ Gen.zip(typeRefGen, g).map { - case (t, r@(Var(_) | Apply(_, _, _) | Parens(_))) => Annotation(r, t)(emptyRegion) + case (t, r @ (Var(_) | Apply(_, _, _) | Parens(_))) => + Annotation(r, t)(emptyRegion) case (t, wrap) => Annotation(Parens(wrap)(emptyRegion), t)(emptyRegion) } } - /** - * Generate a Declaration that can be parsed as a pattern - */ + /** Generate a Declaration that can be parsed as a pattern + */ def patternDecl(depth: Int): Gen[NonBinding] = { import Declaration._ val recur = Gen.lzy(patternDecl(depth - 1)) @@ -596,12 +733,14 @@ object Generators { val applyCons = applyGen(consDeclGen, recur, Gen.const(false)) if (depth <= 0) unnestedDeclGen - else Gen.frequency( - (12, unnestedDeclGen), - (2, applyCons), - (1, recur.map(Parens(_)(emptyRegion))), - (1, annGen(recur)), - (1, genListLangCons(varGen, recur).map(ListDecl(_)(emptyRegion)))) + else + Gen.frequency( + (12, unnestedDeclGen), + (2, applyCons), + (1, recur.map(Parens(_)(emptyRegion))), + (1, annGen(recur)), + (1, genListLangCons(varGen, recur).map(ListDecl(_)(emptyRegion))) + ) } def simpleDecl(depth: Int): Gen[NonBinding] = { @@ -611,33 +750,45 @@ object Generators { val recur = Gen.lzy(simpleDecl(depth - 1)) if (depth <= 0) unnested - else Gen.frequency( - (13, unnested), - (2, lambdaGen(recur)), - (2, applyGen(recur)), - (1, applyOpGen(recur)), - (1, genStringDecl(recur)), - (1, listGen(recur)), - (1, dictGen(recur)), - (1, annGen(recur)), - (1, Gen.choose(0, 4).flatMap(Gen.listOfN(_, recur)).map(TupleCons(_)(emptyRegion))) - ) + else + Gen.frequency( + (13, unnested), + (2, lambdaGen(recur)), + (2, applyGen(recur)), + (1, applyOpGen(recur)), + (1, genStringDecl(recur)), + (1, listGen(recur)), + (1, dictGen(recur)), + (1, annGen(recur)), + ( + 1, + Gen + .choose(0, 4) + .flatMap(Gen.listOfN(_, recur)) + .map(TupleCons(_)(emptyRegion)) + ) + ) } def genRecordArg(dgen: Gen[NonBinding]): Gen[Declaration.RecordArg] = - Gen.zip(bindIdentGen, Gen.option(dgen)) + Gen + .zip(bindIdentGen, Gen.option(dgen)) .map { - case (b, None) => Declaration.RecordArg.Simple(b) + case (b, None) => Declaration.RecordArg.Simple(b) case (b, Some(decl)) => Declaration.RecordArg.Pair(b, decl) } - def genRecordDeclaration(dgen: Gen[NonBinding]): Gen[Declaration.RecordConstructor] = { + def genRecordDeclaration( + dgen: Gen[NonBinding] + ): Gen[Declaration.RecordConstructor] = { val args = for { tailSize <- Gen.choose(0, 4) args <- nonEmptyN(genRecordArg(dgen), tailSize) } yield args - Gen.zip(consIdentGen, args).map { case (c, a) => Declaration.RecordConstructor(c, a)(emptyRegion) } + Gen.zip(consIdentGen, args).map { case (c, a) => + Declaration.RecordConstructor(c, a)(emptyRegion) + } } def genNonBinding(depth: Int): Gen[NonBinding] = { @@ -648,28 +799,37 @@ object Generators { val recur = Gen.lzy(genDeclaration(depth - 1)) val recNon = Gen.lzy(genNonBinding(depth - 1)) if (depth <= 0) unnested - else Gen.frequency( - (14, unnested), - (2, lambdaGen(recNon)), - (2, applyGen(recNon)), - (1, applyOpGen(simpleDecl(depth - 1))), - (1, ifElseGen(recNon, recur)), - (1, ternaryGen(recNon)), - (1, genStringDecl(recNon)), - (1, listGen(recNon)), - (1, dictGen(recNon)), - (1, matchGen(recNon, recur)), - (1, matchesGen(recNon)), - (1, Gen.choose(0, 4).flatMap(Gen.listOfN(_, recNon)).map(TupleCons(_)(emptyRegion))), - (1, genRecordDeclaration(recNon)) - ) + else + Gen.frequency( + (14, unnested), + (2, lambdaGen(recNon)), + (2, applyGen(recNon)), + (1, applyOpGen(simpleDecl(depth - 1))), + (1, ifElseGen(recNon, recur)), + (1, ternaryGen(recNon)), + (1, genStringDecl(recNon)), + (1, listGen(recNon)), + (1, dictGen(recNon)), + (1, matchGen(recNon, recur)), + (1, matchesGen(recNon)), + ( + 1, + Gen + .choose(0, 4) + .flatMap(Gen.listOfN(_, recNon)) + .map(TupleCons(_)(emptyRegion)) + ), + (1, genRecordDeclaration(recNon)) + ) } def makeComment(c: CommentStatement[Padding[Declaration]]): Declaration = { import Declaration._ c.on.padded match { case nb: NonBinding => - CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))(emptyRegion) + CommentNB(CommentStatement(c.message, Padding(c.on.lines, nb)))( + emptyRegion + ) case _ => Comment(c)(emptyRegion) } @@ -681,18 +841,29 @@ object Generators { val unnested = unnestedDeclGen val pat: Gen[Pattern.Parsed] = bindIdentGen.map(Pattern.Var(_)) - //val pat = genPattern(0) + // val pat = genPattern(0) val recur = Gen.lzy(genDeclaration(depth - 1)) val recNon = Gen.lzy(genNonBinding(depth - 1)) if (depth <= 0) unnested - else Gen.frequency( - (3, genNonBinding(depth)), - (1, commentGen(padding(recur, 1)).map(makeComment)), // make sure we have 1 space to prevent comments following each other - (1, defGen(Gen.zip(optIndent(recur), padding(recur, 1))).map(DefFn(_)(emptyRegion))), - (1, bindGen(pat, recNon, padding(recur, 1)).map(Binding(_)(emptyRegion))), - (1, leftApplyGen(pat, recNon, recur)) - ) + else + Gen.frequency( + (3, genNonBinding(depth)), + ( + 1, + commentGen(padding(recur, 1)).map(makeComment) + ), // make sure we have 1 space to prevent comments following each other + ( + 1, + defGen(Gen.zip(optIndent(recur), padding(recur, 1))) + .map(DefFn(_)(emptyRegion)) + ), + ( + 1, + bindGen(pat, recNon, padding(recur, 1)).map(Binding(_)(emptyRegion)) + ), + (1, leftApplyGen(pat, recNon, recur)) + ) } implicit val shrinkDecl: Shrink[Declaration] = @@ -705,7 +876,7 @@ object Generators { case Apply(fn, args, _) => val next = fn #:: args.toList.toStream next.flatMap(apply _) - case ao@ApplyOp(left, _, right) => + case ao @ ApplyOp(left, _, right) => left #:: ao.opVar #:: right #:: Stream.empty case Binding(b) => val next = b.value #:: b.in.padded #:: Stream.empty @@ -723,26 +894,29 @@ object Generators { // todo, we should really interleave shrinking r and b r #:: b.padded #:: Stream.empty case Match(_, _, args) => - args.get.toList.toStream.flatMap { - case (_, decl) => decl.get #:: apply(decl.get) + args.get.toList.toStream.flatMap { case (_, decl) => + decl.get #:: apply(decl.get) } case Matches(a, _) => a #:: apply(a) // the rest can't be shrunk - case Comment(c) => c.on.padded #:: Stream.empty - case CommentNB(c) => c.on.padded #:: Stream.empty + case Comment(c) => c.on.padded #:: Stream.empty + case CommentNB(c) => c.on.padded #:: Stream.empty case Lambda(_, body) => body #:: Stream.empty - case Literal(_) => Stream.empty - case Parens(_) => + case Literal(_) => Stream.empty + case Parens(_) => // by removing parens we can make invalid // expressions Stream.empty case TupleCons(Nil) => Stream.empty - case TupleCons(h :: tail) => h #:: TupleCons(tail)(emptyRegion) #:: apply(TupleCons(tail)(emptyRegion)) + case TupleCons(h :: tail) => + h #:: TupleCons(tail)(emptyRegion) #:: apply( + TupleCons(tail)(emptyRegion) + ) case Var(_) => Stream.empty case StringDecl(parts) => parts.toList.toStream.map { - case Left(nb) => nb + case Left(nb) => nb case Right((r, str)) => Literal(Lit.Str(str))(r) } case ListDecl(ListLang.Cons(items)) => @@ -757,19 +931,25 @@ object Generators { def head: Stream[Declaration] = args.head match { case RecordArg.Pair(n, d) => Stream(Var(n)(emptyRegion), d) - case RecordArg.Simple(n) => Stream(Var(n)(emptyRegion)) + case RecordArg.Simple(n) => Stream(Var(n)(emptyRegion)) } - def tailStream(of: NonEmptyList[RecordArg]): Stream[NonEmptyList[RecordArg]] = + def tailStream( + of: NonEmptyList[RecordArg] + ): Stream[NonEmptyList[RecordArg]] = NonEmptyList.fromList(of.tail) match { case None => Stream.empty case Some(tailArgs) => - tailArgs #:: tailStream(tailArgs) #::: tailStream(NonEmptyList(of.head, tailArgs.tail)) + tailArgs #:: tailStream(tailArgs) #::: tailStream( + NonEmptyList(of.head, tailArgs.tail) + ) } Var(n)(emptyRegion) #:: head #::: - tailStream(args).map(RecordConstructor(n, _)(emptyRegion): Declaration) // type annotation for scala 2.11 + tailStream(args).map( + RecordConstructor(n, _)(emptyRegion): Declaration + ) // type annotation for scala 2.11 } }) @@ -778,13 +958,14 @@ object Generators { import Statement._ def apply(s: Statement): Stream[Statement] = s match { - case Bind(bs@BindingStatement(_, d, _)) => + case Bind(bs @ BindingStatement(_, d, _)) => shrinkDecl.shrink(d).collect { case sd: NonBinding => Bind(bs.copy(value = sd))(emptyRegion) } case Def(ds) => val body = ds.result - body.traverse(shrinkDecl.shrink(_)) + body + .traverse(shrinkDecl.shrink(_)) .map { bod => Def(ds.copy(result = bod))(emptyRegion) } @@ -792,17 +973,21 @@ object Generators { } }) - val constructorGen: Gen[(Identifier.Constructor, List[(Identifier.Bindable, Option[TypeRef])])] = + val constructorGen: Gen[ + (Identifier.Constructor, List[(Identifier.Bindable, Option[TypeRef])]) + ] = for { name <- consIdentGen args <- smallList(argGen) } yield (name, args) val genTypeArgs: Gen[List[(TypeRef.TypeVar, Option[Kind.Arg])]] = - smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKindArg))).map(_.distinctBy(_._1)) + smallList(Gen.zip(typeRefVarGen, Gen.option(NTypeGen.genKindArg))) + .map(_.distinctBy(_._1)) val genStruct: Gen[Statement] = - Gen.zip(constructorGen, genTypeArgs) + Gen + .zip(constructorGen, genTypeArgs) .map { case ((name, args), ta) => Statement.Struct(name, NonEmptyList.fromList(ta), args)(emptyRegion) } @@ -836,14 +1021,22 @@ object Generators { // TODO make more powerful val pat: Gen[Pattern.Parsed] = genPattern(1) Gen.frequency( - (1, bindGen(pat, nonB, Gen.const(())).map(Statement.Bind(_)(emptyRegion))), + ( + 1, + bindGen(pat, nonB, Gen.const(())).map(Statement.Bind(_)(emptyRegion)) + ), (1, commentGen(Gen.const(())).map(Statement.Comment(_)(emptyRegion))), (1, defGen(optIndent(decl)).map(Statement.Def(_)(emptyRegion))), (1, genStruct), (1, genExternalStruct), (1, genExternalDef), (1, genEnum), - (1, padding(Gen.const(()), 1).map(Statement.PaddingStatement(_)(emptyRegion)))) + ( + 1, + padding(Gen.const(()), 1) + .map(Statement.PaddingStatement(_)(emptyRegion)) + ) + ) } def genStatements(depth: Int, maxLength: Int): Gen[List[Statement]] = { @@ -855,12 +1048,22 @@ object Generators { */ def combineDuplicates(stmts: List[Statement]): List[Statement] = stmts match { - case Nil => Nil + case Nil => Nil case h :: Nil => h :: Nil - case PaddingStatement(Padding(a, _)) :: PaddingStatement(Padding(b, _)) :: rest => - combineDuplicates(PaddingStatement(Padding(a + b, ()))(emptyRegion) :: rest) - case Comment(CommentStatement(lines1, _)) :: Comment(CommentStatement(lines2, _)) :: rest => - combineDuplicates(Comment(CommentStatement(lines1 ::: lines2, ()))(emptyRegion) :: rest) + case PaddingStatement(Padding(a, _)) :: PaddingStatement( + Padding(b, _) + ) :: rest => + combineDuplicates( + PaddingStatement(Padding(a + b, ()))(emptyRegion) :: rest + ) + case Comment(CommentStatement(lines1, _)) :: Comment( + CommentStatement(lines2, _) + ) :: rest => + combineDuplicates( + Comment(CommentStatement(lines1 ::: lines2, ()))( + emptyRegion + ) :: rest + ) case h1 :: rest => h1 :: combineDuplicates(rest) } @@ -897,24 +1100,27 @@ object Generators { Gen.oneOf( bindIdentGen.map(ExportedName.Binding(_, ())), consIdentGen.map(ExportedName.TypeName(_, ())), - consIdentGen.map(ExportedName.Constructor(_, ()))) + consIdentGen.map(ExportedName.Constructor(_, ())) + ) def smallList[A](g: Gen[A]): Gen[List[A]] = - Gen.choose(0, 8).flatMap(Gen.listOfN(_, g)) def smallNonEmptyList[A](g: Gen[A], maxLen: Int): Gen[NonEmptyList[A]] = // bias to small numbers - Gen.geometric(2.0) + Gen + .geometric(2.0) .flatMap { case n if n <= 0 => g.map(NonEmptyList.one) case n => - Gen.zip(g, Gen.listOfN((n - 1) min (maxLen - 1), g)) + Gen + .zip(g, Gen.listOfN((n - 1) min (maxLen - 1), g)) .map { case (h, t) => NonEmptyList(h, t) } } def smallDistinctByList[A, B](g: Gen[A])(fn: A => B): Gen[List[A]] = - Gen.choose(0, 8) + Gen + .choose(0, 8) .flatMap(Gen.listOfN(_, g)) .map(graph.Tree.distinctBy(_)(fn)) @@ -927,8 +1133,11 @@ object Generators { body <- genStatements(depth, 10) } yield Package(p, imports, exports, body) - - def genDefinedType[A](p: PackageName, inner: Gen[A], genType: Gen[rankn.Type]): Gen[rankn.DefinedType[A]] = + def genDefinedType[A]( + p: PackageName, + inner: Gen[A], + genType: Gen[rankn.Type] + ): Gen[rankn.DefinedType[A]] = for { t <- typeNameGen paramKeys <- smallList(NTypeGen.genBound).map(_.distinct) @@ -953,13 +1162,16 @@ object Generators { } yield ExportedName.Binding(n, Referant.Value(t)) te.allDefinedTypes match { - case Nil => bind(NTypeGen.genDepth03) + case Nil => bind(NTypeGen.genDepth03) case dts0 => // only make one of each type val dts = dts0.map { dt => (dt.name.ident, dt) }.toMap.values.toList - val b = bind(Gen.oneOf(NTypeGen.genDepth03, Gen.oneOf(dts).map(_.toTypeTyConst))) - val genExpT = Gen.oneOf(dts) + val b = bind( + Gen.oneOf(NTypeGen.genDepth03, Gen.oneOf(dts).map(_.toTypeTyConst)) + ) + val genExpT = Gen + .oneOf(dts) .map { dt => ExportedName.TypeName(dt.name.ident, Referant.DefinedT(dt)) } @@ -970,7 +1182,10 @@ object Generators { val c = for { dt <- Gen.oneOf(nonEmpty) cf <- Gen.oneOf(dt.constructors) - } yield ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) + } yield ExportedName.Constructor( + cf.name, + Referant.Constructor(dt, cf) + ) Gen.oneOf(b, genExpT, c) } } @@ -979,55 +1194,118 @@ object Generators { val interfaceGen: Gen[Package.Interface] = for { p <- packageNameGen - te <- typeEnvGen(p, Gen.oneOf(Kind.Type.co, Kind.Type.phantom, Kind.Type.contra, Kind.Type.in)) + te <- typeEnvGen( + p, + Gen.oneOf( + Kind.Type.co, + Kind.Type.phantom, + Kind.Type.contra, + Kind.Type.in + ) + ) exs0 <- smallList(exportGen(te)) - exs = exs0.map { ex => (ex.name, ex) }.toMap.values.toList // don't duplicate exported names + exs = exs0 + .map { ex => (ex.name, ex) } + .toMap + .values + .toList // don't duplicate exported names } yield Package(p, Nil, exs, ()) - /** - * This is a totally random, and not well typed expression. - * It is suitable for some tests, but it is not a valid output - * of a typechecking process - */ - def genTypedExpr[A](genTag: Gen[A], depth: Int, typeGen: Gen[rankn.Type]): Gen[TypedExpr[A]] = { + /** This is a totally random, and not well typed expression. It is suitable + * for some tests, but it is not a valid output of a typechecking process + */ + def genTypedExpr[A]( + genTag: Gen[A], + depth: Int, + typeGen: Gen[rankn.Type] + ): Gen[TypedExpr[A]] = { val recurse = Gen.lzy(genTypedExpr(genTag, depth - 1, typeGen)) - val lit = Gen.zip(genLit, NTypeGen.genDepth03, genTag).map { case (l, tpe, tag) => TypedExpr.Literal(l, tpe, tag) } + val lit = Gen.zip(genLit, NTypeGen.genDepth03, genTag).map { + case (l, tpe, tag) => TypedExpr.Literal(l, tpe, tag) + } // only literal doesn't recurse if (depth <= 0) lit else { val genGeneric = - Gen.zip(Generators.nonEmpty(Gen.zip(NTypeGen.genBound, NTypeGen.genKind)), recurse) + Gen + .zip( + Generators.nonEmpty(Gen.zip(NTypeGen.genBound, NTypeGen.genKind)), + recurse + ) .map { case (vs, t) => TypedExpr.Generic(vs, t) } val ann = - Gen.zip(recurse, typeGen) + Gen + .zip(recurse, typeGen) .map { case (te, tpe) => TypedExpr.Annotation(te, tpe) } val lam = - Gen.zip(smallNonEmptyList(Gen.zip(bindIdentGen, typeGen), 8), recurse, genTag) - .map { case (args, res, tag) => TypedExpr.AnnotatedLambda(args, res, tag) } + Gen + .zip( + smallNonEmptyList(Gen.zip(bindIdentGen, typeGen), 8), + recurse, + genTag + ) + .map { case (args, res, tag) => + TypedExpr.AnnotatedLambda(args, res, tag) + } val localGen = - Gen.zip(bindIdentGen, typeGen, genTag) + Gen + .zip(bindIdentGen, typeGen, genTag) .map { case (n, t, tag) => TypedExpr.Local(n, t, tag) } val globalGen = - Gen.zip(packageNameGen, identifierGen, typeGen, genTag) + Gen + .zip(packageNameGen, identifierGen, typeGen, genTag) .map { case (p, n, t, tag) => TypedExpr.Global(p, n, t, tag) } val app = - Gen.zip(recurse, smallNonEmptyList(recurse, 8), typeGen, genTag) - .map { case (fn, args, tpe, tag) => TypedExpr.App(fn, args, tpe, tag) } + Gen + .zip(recurse, smallNonEmptyList(recurse, 8), typeGen, genTag) + .map { case (fn, args, tpe, tag) => + TypedExpr.App(fn, args, tpe, tag) + } val let = - Gen.zip(bindIdentGen, recurse, recurse, Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), genTag) - .map { case (n, ex, in, rec, tag) => TypedExpr.Let(n, ex, in, rec, tag) } + Gen + .zip( + bindIdentGen, + recurse, + recurse, + Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), + genTag + ) + .map { case (n, ex, in, rec, tag) => + TypedExpr.Let(n, ex, in, rec, tag) + } val matchGen = - Gen.zip(recurse, Gen.choose(1, 4).flatMap(nonEmptyN(Gen.zip(genCompiledPattern(depth), recurse), _)), genTag) - .map { case (arg, branches, tag) => TypedExpr.Match(arg, branches, tag) } + Gen + .zip( + recurse, + Gen + .choose(1, 4) + .flatMap( + nonEmptyN(Gen.zip(genCompiledPattern(depth), recurse), _) + ), + genTag + ) + .map { case (arg, branches, tag) => + TypedExpr.Match(arg, branches, tag) + } - Gen.oneOf(genGeneric, ann, lam, localGen, globalGen, app, let, lit, matchGen) + Gen.oneOf( + genGeneric, + ann, + lam, + localGen, + globalGen, + app, + let, + lit, + matchGen + ) } } @@ -1044,7 +1322,8 @@ object Generators { def loop(idx: Int): Gen[List[Int]] = if (idx >= size) Gen.const(Nil) else - Gen.zip(Gen.choose(idx, size - 1), loop(idx + 1)) + Gen + .zip(Gen.choose(idx, size - 1), loop(idx + 1)) .map { case (h, tail) => h :: tail } loop(0).map { swaps => @@ -1056,21 +1335,32 @@ object Generators { } ary.toList } - } + } - def genOnePackage[A](genA: Gen[A], existing: Map[PackageName, Package.Typed[A]]): Gen[Package.Typed[A]] = { + def genOnePackage[A]( + genA: Gen[A], + existing: Map[PackageName, Package.Typed[A]] + ): Gen[Package.Typed[A]] = { val genDeps: Gen[Map[PackageName, Package.Typed[A]]] = Gen.frequency( - (5, Gen.const(Map.empty)), // usually have no deps, otherwise the graph gets enormous + ( + 5, + Gen.const(Map.empty) + ), // usually have no deps, otherwise the graph gets enormous (1, shuffle(existing.toList).map(_.take(2).toMap)) ) - def impFromExp(exp: List[(Package.Interface, ExportedName[Referant[Kind.Arg]])]): Gen[List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]] = - exp.groupBy(_._1) + def impFromExp( + exp: List[(Package.Interface, ExportedName[Referant[Kind.Arg]])] + ): Gen[List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]] = + exp + .groupBy(_._1) .toList .traverse { case (p, exps) => - val genImps: Gen[List[ImportedName[NonEmptyList[Referant[Kind.Arg]]]]] = - exps.groupBy(_._2.name) + val genImps + : Gen[List[ImportedName[NonEmptyList[Referant[Kind.Arg]]]]] = + exps + .groupBy(_._2.name) .iterator .toList .traverse { case (ident, exps) => @@ -1100,7 +1390,9 @@ object Generators { } } - val genImports: Gen[List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]] = + val genImports: Gen[ + List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] + ] = genDeps.flatMap { packs => val exps: List[(Package.Interface, ExportedName[Referant[Kind.Arg]])] = (for { @@ -1115,90 +1407,135 @@ object Generators { } yield imp } - def definedTypesFromImp(i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]): List[rankn.Type.Const] = + def definedTypesFromImp( + i: Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]] + ): List[rankn.Type.Const] = i.items.toList.flatMap { in => in.tag.toList.flatMap { - case Referant.DefinedT(dt) => dt.toTypeConst :: Nil + case Referant.DefinedT(dt) => dt.toTypeConst :: Nil case Referant.Constructor(dt, _) => dt.toTypeConst :: Nil - case Referant.Value(_) => Nil + case Referant.Value(_) => Nil } } - def genTypeEnv(pn: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]): StateT[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), Unit] = - StateT.get[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable])] - .flatMap { case (te, extDefs) => - StateT.liftF(Gen.choose(0, 9)) - .flatMap { - case 0 => - // 1 in 10 chance of stopping - StateT.pure[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), Unit](()) - case _ => - // add something: - val tyconsts = - te.allDefinedTypes.map(_.toTypeConst) ++ - imps.flatMap(definedTypesFromImp) - val theseTypes = NTypeGen.genDepth(4, if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts))) - val genV: Gen[Kind.Arg] = - Gen.oneOf(Kind.Type.co, Kind.Type.contra, Kind.Type.in, Kind.Type.phantom) - val genDT = genDefinedType(pn, genV, theseTypes) - val genEx: Gen[(Identifier.Bindable, rankn.Type)] = - Gen.zip(bindIdentGen, theseTypes) - - // we can do one of the following: - // 1: add an external value - // 2: add a defined type - StateT.liftF(Gen.frequency( - (5, genDT.map { dt => (te.addDefinedTypeAndConstructors(dt), extDefs) }), - (1, genEx.map { case (b, t) => (te.addExternalValue(pn, b, t), extDefs + b) }))) - .flatMap(StateT.set(_)) - } - } + def genTypeEnv( + pn: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] + ): StateT[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), Unit] = + StateT + .get[Gen, (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable])] + .flatMap { case (te, extDefs) => + StateT + .liftF(Gen.choose(0, 9)) + .flatMap { + case 0 => + // 1 in 10 chance of stopping + StateT.pure[ + Gen, + (rankn.TypeEnv[Kind.Arg], Set[Identifier.Bindable]), + Unit + ](()) + case _ => + // add something: + val tyconsts = + te.allDefinedTypes.map(_.toTypeConst) ++ + imps.flatMap(definedTypesFromImp) + val theseTypes = NTypeGen.genDepth( + 4, + if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts)) + ) + val genV: Gen[Kind.Arg] = + Gen.oneOf( + Kind.Type.co, + Kind.Type.contra, + Kind.Type.in, + Kind.Type.phantom + ) + val genDT = genDefinedType(pn, genV, theseTypes) + val genEx: Gen[(Identifier.Bindable, rankn.Type)] = + Gen.zip(bindIdentGen, theseTypes) + + // we can do one of the following: + // 1: add an external value + // 2: add a defined type + StateT + .liftF( + Gen.frequency( + ( + 5, + genDT.map { dt => + (te.addDefinedTypeAndConstructors(dt), extDefs) + } + ), + ( + 1, + genEx.map { case (b, t) => + (te.addExternalValue(pn, b, t), extDefs + b) + } + ) + ) + ) + .flatMap(StateT.set(_)) + } + } - def genLets(te: rankn.TypeEnv[Kind.Arg], - exts: Set[Identifier.Bindable]): Gen[List[(Identifier.Bindable, RecursionKind, TypedExpr[A])]] = { - val allTC = te.allDefinedTypes.map(_.toTypeConst) - val theseTypes = NTypeGen.genDepth(4, if (allTC.isEmpty) None else Some(Gen.oneOf(allTC))) - val oneLet = Gen.zip(bindIdentGen.filter { b => !exts(b) }, - Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), - genTypedExpr(genA, 4, theseTypes)) + def genLets( + te: rankn.TypeEnv[Kind.Arg], + exts: Set[Identifier.Bindable] + ): Gen[List[(Identifier.Bindable, RecursionKind, TypedExpr[A])]] = { + val allTC = te.allDefinedTypes.map(_.toTypeConst) + val theseTypes = NTypeGen.genDepth( + 4, + if (allTC.isEmpty) None else Some(Gen.oneOf(allTC)) + ) + val oneLet = Gen.zip( + bindIdentGen.filter { b => !exts(b) }, + Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive), + genTypedExpr(genA, 4, theseTypes) + ) - Gen.choose(0, 6).flatMap(Gen.listOfN(_, oneLet)) - } + Gen.choose(0, 6).flatMap(Gen.listOfN(_, oneLet)) + } def genProg( - pn: PackageName, - imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]]): Gen[Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any]] = - genTypeEnv(pn, imps) - .runS((rankn.TypeEnv.empty, Set.empty)) - .flatMap { case (te, b) => - genLets(te, b).map(Program(te, _, b.toList.sorted, ())) - } + pn: PackageName, + imps: List[Import[Package.Interface, NonEmptyList[Referant[Kind.Arg]]]] + ): Gen[Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any]] = + genTypeEnv(pn, imps) + .runS((rankn.TypeEnv.empty, Set.empty)) + .flatMap { case (te, b) => + genLets(te, b).map(Program(te, _, b.toList.sorted, ())) + } /* * Exports are types, constructors, or values */ - def genExports(pn: PackageName, p: Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any]): Gen[List[ExportedName[Referant[Kind.Arg]]]] = { + def genExports( + pn: PackageName, + p: Program[rankn.TypeEnv[Kind.Arg], TypedExpr[A], Any] + ): Gen[List[ExportedName[Referant[Kind.Arg]]]] = { def expnames: List[ExportedName[Referant[Kind.Arg]]] = p.lets.map { case (n, _, te) => ExportedName.Binding(n, Referant.Value(te.getType)) } def exts: List[ExportedName[Referant[Kind.Arg]]] = p.externalDefs.flatMap { n => - p.types.getValue(pn, n).map { t => ExportedName.Binding(n, Referant.Value(t)) } + p.types.getValue(pn, n).map { t => + ExportedName.Binding(n, Referant.Value(t)) + } } def cons: List[ExportedName[Referant[Kind.Arg]]] = p.types.allDefinedTypes.flatMap { dt => if (dt.packageName == pn) { - val dtex = ExportedName.TypeName(dt.name.ident, Referant.DefinedT(dt)) + val dtex = + ExportedName.TypeName(dt.name.ident, Referant.DefinedT(dt)) val cons = dt.constructors.map { cf => ExportedName.Constructor(cf.name, Referant.Constructor(dt, cf)) } dtex :: cons - } - else Nil + } else Nil } for { @@ -1217,20 +1554,32 @@ object Generators { } yield Package(pn, imps, exps, prog) } - def genPackagesSt[A](genA: Gen[A], maxSize: Int): StateT[Gen, Map[PackageName, Package.Typed[A]], Unit] = - StateT.get[Gen, Map[PackageName, Package.Typed[A]]] + def genPackagesSt[A]( + genA: Gen[A], + maxSize: Int + ): StateT[Gen, Map[PackageName, Package.Typed[A]], Unit] = + StateT + .get[Gen, Map[PackageName, Package.Typed[A]]] .flatMap { m => if (m.size >= maxSize) StateT.pure(()) else { // make one more and try again for { - p <- StateT.liftF[Gen, Map[PackageName, Package.Typed[A]], Package.Typed[A]](genOnePackage(genA, m)) - _ <- StateT.set[Gen, Map[PackageName, Package.Typed[A]]](m.updated(p.name, p)) + p <- StateT + .liftF[Gen, Map[PackageName, Package.Typed[A]], Package.Typed[A]]( + genOnePackage(genA, m) + ) + _ <- StateT.set[Gen, Map[PackageName, Package.Typed[A]]]( + m.updated(p.name, p) + ) _ <- genPackagesSt(genA, maxSize) } yield () } } - def genPackage[A](genA: Gen[A], maxSize: Int): Gen[Map[PackageName, Package.Typed[A]]] = + def genPackage[A]( + genA: Gen[A], + maxSize: Int + ): Gen[Map[PackageName, Package.Typed[A]]] = genPackagesSt(genA, maxSize).runS(Map.empty) } diff --git a/core/src/test/scala/org/bykn/bosatsu/GenJson.scala b/core/src/test/scala/org/bykn/bosatsu/GenJson.scala index 1b3a24825..a527929fe 100644 --- a/core/src/test/scala/org/bykn/bosatsu/GenJson.scala +++ b/core/src/test/scala/org/bykn/bosatsu/GenJson.scala @@ -7,33 +7,38 @@ object GenJson { val genJsonNumber: Gen[Json.JNumberStr] = { def cat(gs: List[Gen[String]]): Gen[String] = gs match { - case Nil => Gen.const("") + case Nil => Gen.const("") case h :: tail => Gen.zip(h, cat(tail)).map { case (a, b) => a + b } } val digit09 = Gen.oneOf('0' to '9').map(_.toString) val digit19 = Gen.oneOf('1' to '9').map(_.toString) val digits = Gen.listOf(digit09).map(_.mkString) - val digits1 = Gen.zip(digit09, Gen.listOf(digit09)).map { case (h, t) => (h :: t).mkString } + val digits1 = Gen.zip(digit09, Gen.listOf(digit09)).map { case (h, t) => + (h :: t).mkString + } val int = Gen.frequency( (1, Gen.const("0")), - (20, Gen.zip(digit19, digits).map { case (h, t) => h + t })) + (20, Gen.zip(digit19, digits).map { case (h, t) => h + t }) + ) val frac = digits1.map("." + _) def opt(g: Gen[String]): Gen[String] = Gen.oneOf(true, false).flatMap { - case true => g + case true => g case false => Gen.const("") } val exp = cat(List(Gen.oneOf("e", "E"), opt(Gen.oneOf("+", "-")), digits1)) - cat(List(opt(Gen.const("-")), int, opt(frac), opt(exp))).map(Json.JNumberStr(_)) + cat(List(opt(Gen.const("-")), int, opt(frac), opt(exp))) + .map(Json.JNumberStr(_)) } def genJson(depth: Int): Gen[Json] = { val genString = Gen.listOf(Gen.choose(1.toChar, 127.toChar)).map(_.mkString) val str = genString.map(Json.JString(_)) val nd1 = Arbitrary.arbitrary[Long].map { i => Json.JNumberStr(i.toString) } - val nd2 = Arbitrary.arbitrary[Double].map { d => Json.JNumberStr(d.toString) } + val nd2 = + Arbitrary.arbitrary[Double].map { d => Json.JNumberStr(d.toString) } val nd3 = Arbitrary.arbitrary[Int].map { i => Json.JNumberStr(i.toString) } val b = Gen.oneOf(Json.JBool(true), Json.JBool(false)) @@ -42,9 +47,12 @@ object GenJson { else { val recurse = Gen.lzy(genJson(depth - 1)) val collectionSize = Gen.choose(0, depth * depth) - val ary = collectionSize.flatMap(Gen.listOfN(_, recurse).map { l => Json.JArray(l.toVector) }) + val ary = collectionSize.flatMap( + Gen.listOfN(_, recurse).map { l => Json.JArray(l.toVector) } + ) val map = collectionSize.flatMap { sz => - Gen.listOfN(sz, Gen.zip(genString, recurse)) + Gen + .listOfN(sz, Gen.zip(genString, recurse)) .map { m => Json.JObject(m).normalize } } Gen.frequency((10, d0), (1, ary), (1, map)) @@ -54,16 +62,16 @@ object GenJson { implicit val arbJson: Arbitrary[Json] = Arbitrary(Gen.choose(0, 4).flatMap(genJson(_))) - implicit def shrinkJson( - implicit ss: Shrink[String] + implicit def shrinkJson(implicit + ss: Shrink[String] ): Shrink[Json] = Shrink[Json](new Function1[Json, Stream[Json]] { def apply(j: Json): Stream[Json] = { import Json._ j match { - case JString(str) => ss.shrink(str).map(JString(_)) - case JNumberStr(_) => Stream.empty - case JNull => Stream.empty + case JString(str) => ss.shrink(str).map(JString(_)) + case JNumberStr(_) => Stream.empty + case JNull => Stream.empty case JBool.True | JBool.False => Stream.empty case JArray(js) => (0 until js.size).toStream.map { sz => diff --git a/core/src/test/scala/org/bykn/bosatsu/GenValue.scala b/core/src/test/scala/org/bykn/bosatsu/GenValue.scala index 0f1f1b99d..509d2acef 100644 --- a/core/src/test/scala/org/bykn/bosatsu/GenValue.scala +++ b/core/src/test/scala/org/bykn/bosatsu/GenValue.scala @@ -10,8 +10,12 @@ object GenValue { Cogen[Int].contramap { (v: Value) => v.hashCode } lazy val genProd: Gen[ProductValue] = - Gen.lzy(Gen.oneOf(Gen.const(UnitValue), - genValue.flatMap { v => genProd.map(ConsValue(v, _)) })) + Gen.lzy( + Gen.oneOf( + Gen.const(UnitValue), + genValue.flatMap { v => genProd.map(ConsValue(v, _)) } + ) + ) lazy val genValue: Gen[Value] = { val recur = Gen.lzy(genValue) @@ -26,7 +30,8 @@ object GenValue { val genExt: Gen[Value] = Gen.oneOf( Gen.choose(Int.MinValue, Int.MaxValue).map(VInt(_)), - Arbitrary.arbitrary[String].map(Str(_))) + Arbitrary.arbitrary[String].map(Str(_)) + ) val genFn: Gen[FnValue] = { val fn: Gen[NonEmptyList[Value] => Value] = Gen.function1(recur)( diff --git a/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala b/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala index 64995f770..3e44c5f7d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/IntLaws.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import java.math.BigInteger import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite object IntLaws { @@ -20,26 +23,56 @@ class IntLaws extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 50000) - //PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 500) + // PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 500) val genBI: Gen[BigInteger] = - Gen.choose(-128L, 128L) + Gen + .choose(-128L, 128L) .map(BigInteger.valueOf(_)) test("match python on some examples") { - assert(BigInteger.valueOf(4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf(-2L)) - - assert(BigInteger.valueOf(-8L) % BigInteger.valueOf(-2L) == BigInteger.valueOf(0L)) - assert(BigInteger.valueOf(-8L) / BigInteger.valueOf(-2L) == BigInteger.valueOf(4L)) - - assert(BigInteger.valueOf(-4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf(-1L)) - assert(BigInteger.valueOf(13L) % BigInteger.valueOf(3L) == BigInteger.valueOf(1L)) - assert(BigInteger.valueOf(-113L) / BigInteger.valueOf(16L) == BigInteger.valueOf(-8L)) - - - assert(BigInteger.valueOf(54L) % BigInteger.valueOf(-3L) == BigInteger.valueOf(0L)) - assert(BigInteger.valueOf(54L) / BigInteger.valueOf(-3L) == BigInteger.valueOf(-18L)) + assert( + BigInteger.valueOf(4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf( + -2L + ) + ) + + assert( + BigInteger.valueOf(-8L) % BigInteger.valueOf(-2L) == BigInteger.valueOf( + 0L + ) + ) + assert( + BigInteger.valueOf(-8L) / BigInteger.valueOf(-2L) == BigInteger.valueOf( + 4L + ) + ) + + assert( + BigInteger.valueOf(-4L) % BigInteger.valueOf(-3L) == BigInteger.valueOf( + -1L + ) + ) + assert( + BigInteger.valueOf(13L) % BigInteger.valueOf(3L) == BigInteger.valueOf(1L) + ) + assert( + BigInteger.valueOf(-113L) / BigInteger.valueOf(16L) == BigInteger.valueOf( + -8L + ) + ) + + assert( + BigInteger.valueOf(54L) % BigInteger.valueOf(-3L) == BigInteger.valueOf( + 0L + ) + ) + assert( + BigInteger.valueOf(54L) / BigInteger.valueOf(-3L) == BigInteger.valueOf( + -18L + ) + ) } test("a = (a / b) * b + (a % b)") { @@ -107,7 +140,9 @@ class IntLaws extends AnyFunSuite { test("a / b <= a if b >= 0 and a >= 0") { forAll(genBI, genBI) { (a, b) => - if (b.compareTo(BigInteger.ZERO) >= 0 && a.compareTo(BigInteger.ZERO) >= 0) { + if ( + b.compareTo(BigInteger.ZERO) >= 0 && a.compareTo(BigInteger.ZERO) >= 0 + ) { val div = a / b assert(div.compareTo(a) <= 0, div) } @@ -133,7 +168,7 @@ class IntLaws extends AnyFunSuite { forAll(genBI, genBI) { (a, b) => val mod = a % b if (mod == BigInteger.ZERO) { - assert((a/b)*b == a) + assert((a / b) * b == a) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala b/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala index 2d55427bc..1183e48a6 100644 --- a/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/JsonTest.scala @@ -3,7 +3,10 @@ package org.bykn.bosatsu import cats.Eq import cats.implicits._ import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import TestUtils.typeEnvOf import rankn.{NTypeGen, Type, TypeEnv} @@ -14,7 +17,9 @@ import org.scalatest.funsuite.AnyFunSuite class JsonTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 20 + ) def law(j: Json) = assert(Parser.unsafeParse(Json.parser, j.render) == j) @@ -25,7 +30,10 @@ class JsonTest extends AnyFunSuite { .flatMap { te => val tyconsts = te.allDefinedTypes.map(_.toTypeConst) - val theseTypes = NTypeGen.genDepth(4, if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts))) + val theseTypes = NTypeGen.genDepth( + 4, + if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts)) + ) theseTypes.map((te, _)) } @@ -37,10 +45,19 @@ class JsonTest extends AnyFunSuite { optTE = if (none) None else Some(te) } yield (optTE, tpe) - test("test some example escapes") { - assert(Parser.unsafeParse(JsonStringUtil.escapedToken.string, "\\u0000") == "\\u0000") - assert(Parser.unsafeParse(JsonStringUtil.escapedString('\''), "'\\u0000'") == 0.toChar.toString) + assert( + Parser.unsafeParse( + JsonStringUtil.escapedToken.string, + "\\u0000" + ) == "\\u0000" + ) + assert( + Parser.unsafeParse( + JsonStringUtil.escapedString('\''), + "'\\u0000'" + ) == 0.toChar.toString + ) } test("we can parse all the json we generate") { @@ -51,13 +68,12 @@ class JsonTest extends AnyFunSuite { forAll(genJsonNumber)(law(_)) forAll(genJsonNumber) { num => - val parts = Parser.unsafeParse(Parser.JsonNumber.partsParser, num.asString) + val parts = + Parser.unsafeParse(Parser.JsonNumber.partsParser, num.asString) assert(parts.asString == num.asString) } - val regressions = List( - Json.JNumberStr("2E9"), - Json.JNumberStr("-9E+19")) + val regressions = List(Json.JNumberStr("2E9"), Json.JNumberStr("-9E+19")) regressions.foreach { n => law(n) @@ -68,7 +84,7 @@ class JsonTest extends AnyFunSuite { def law(te: Option[TypeEnv[Any]], t: Type, j: Json) = { val jsonCodec = te match { - case None => ValueToJson(_ => None) + case None => ValueToJson(_ => None) case Some(te) => ValueToJson(te.toDefinedType(_)) } val toJson = jsonCodec.toJson(t) @@ -84,13 +100,17 @@ class JsonTest extends AnyFunSuite { ej1 match { case Right(j1) => assert(Eq[Json].eqv(j1, j), s"$j1 != $j") - case Left(_) => () + case Left(_) => () } } - forAll(optTE, GenJson.arbJson.arbitrary) { case ((ote, tpe), json) => law(ote, tpe, json) } + forAll(optTE, GenJson.arbJson.arbitrary) { case ((ote, tpe), json) => + law(ote, tpe, json) + } - val regressions = List((None, Type.TyApply(Type.OptionType, Type.BoolType), Json.JBool.False)) + val regressions = List( + (None, Type.TyApply(Type.OptionType, Type.BoolType), Json.JBool.False) + ) regressions.foreach { case (te, t, j) => law(te, t, j) } } @@ -114,7 +134,7 @@ class JsonTest extends AnyFunSuite { def law(ote: Option[TypeEnv[Unit]], t: Type, v: Value) = { val jsonCodec = ote match { - case None => ValueToJson(_ => None) + case None => ValueToJson(_ => None) case Some(te) => ValueToJson(te.toDefinedType(_)) } val toJson = jsonCodec.toJson(t) @@ -130,7 +150,7 @@ class JsonTest extends AnyFunSuite { ej1 match { case Right(v1) => assert(v1 == v, s"$v1 != $v") - case Left(_) => () + case Left(_) => () } } @@ -143,7 +163,9 @@ class JsonTest extends AnyFunSuite { } test("some hand written cases round trip") { - val te = typeEnvOf(PackageName.parts("Test"), """ + val te = typeEnvOf( + PackageName.parts("Test"), + """ struct MyUnit # wrappers are removed @@ -153,19 +175,21 @@ struct MyPair(fst, snd) enum MyEither: L(left), R(right) enum MyNat: Z, S(prev: MyNat) -""") +""" + ) val jsonConv = ValueToJson(te.toDefinedType(_)) def stringToType(t: String): Type = { val tr = Parser.unsafeParse(TypeRef.parser, t) TypeRefConverter[cats.Id](tr) { cons => - te.referencedPackages.toList.flatMap { pack => - val const = Type.Const.Defined(pack, TypeName(cons)) - te.toDefinedType(const).map(_ => const) - } - .headOption - .getOrElse(Type.Const.predef(cons.asString)) + te.referencedPackages.toList + .flatMap { pack => + val const = Type.Const.Defined(pack, TypeName(cons)) + te.toDefinedType(const).map(_ => const) + } + .headOption + .getOrElse(Type.Const.predef(cons.asString)) } } @@ -186,9 +210,11 @@ enum MyNat: Z, S(prev: MyNat) case Right(j1) => assert(Eq[Json].eqv(j1, j), s"$j1 != $j") case Left(err) => fail(err.toString) } - case Left(err) => fail(s"could not handle to Json: $tpe, $t, $toV, $err") + case Left(err) => + fail(s"could not handle to Json: $tpe, $t, $toV, $err") } - case Left(err) => fail(s"could not handle to Value: $tpe, $t, $toJ, $err") + case Left(err) => + fail(s"could not handle to Value: $tpe, $t, $toJ, $err") } } @@ -196,7 +222,7 @@ enum MyNat: Z, S(prev: MyNat) val t = stringToType(tpe) jsonConv.supported(t) match { case Right(_) => fail(s"expected $tpe to be unsupported") - case Left(_) => succeed + case Left(_) => succeed } } @@ -210,7 +236,7 @@ enum MyNat: Z, S(prev: MyNat) assert(toJ.isRight) val j = stringToJson(json) toV(j) match { - case Left(_) => succeed + case Left(_) => succeed case Right(v) => fail(s"expected $json to be ill-typed: $v") } case Left(err) => fail(s"could not handle to Value: $tpe, $t, $err") diff --git a/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala b/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala index 59fa4ad1e..3d92dcf65 100644 --- a/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/KindFormulaTest.scala @@ -32,7 +32,10 @@ class KindFormulaTest extends AnyFunSuite { def testKind(teStr: String, shapes: Map[String, String]) = testKindEither(makeTE(teStr), shapes) - def testKindEither(te: Either[Any, TypeEnv[Kind.Arg]], shapes: Map[String, String]) = + def testKindEither( + te: Either[Any, TypeEnv[Kind.Arg]], + shapes: Map[String, String] + ) = te match { case Right(te) => shapes.foreach { case (n, vs) => @@ -205,9 +208,10 @@ struct Leib[a, b](cast: forall f. f[a] -> f[b]) ) ) } - + test("test Applicative example") { - testKind("""# + testKind( + """# # Represents the Applicative typeclass struct Fn[a: -*, b: +*] struct Unit @@ -220,7 +224,9 @@ struct Applicative( map2: forall a, b, c. f[a] -> f[b] -> (a -> b -> c) -> f[c], product: forall a, b. f[a] -> f[b] -> f[(a, b)]) -""", Map("Applicative" -> "(* -> *) -> *")) +""", + Map("Applicative" -> "(* -> *) -> *") + ) } test("linked list is allowed") { diff --git a/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala b/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala index 7fc3f32a0..1bc3ed6db 100644 --- a/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/LocationMapTest.scala @@ -1,12 +1,17 @@ package org.bykn.bosatsu import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class LocationMapTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 50000 else 100) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 50000 else 100 + ) test("single line locations") { val singleLine: Gen[String] = @@ -30,7 +35,8 @@ class LocationMapTest extends AnyFunSuite { forAll { (str: String) => val lm = LocationMap(str) - val reconstruct = Iterator.iterate(0)(_ + 1) + val reconstruct = Iterator + .iterate(0)(_ + 1) .map(lm.getLine _) .takeWhile(_.isDefined) .collect { case Some(l) => l } @@ -39,7 +45,9 @@ class LocationMapTest extends AnyFunSuite { assert(reconstruct === str) } } - test("toLineCol is defined for all valid offsets, and getLine isDefined consistently") { + test( + "toLineCol is defined for all valid offsets, and getLine isDefined consistently" + ) { forAll { (s: String, offset: Int) => val lm = LocationMap(s) @@ -53,7 +61,8 @@ class LocationMapTest extends AnyFunSuite { case None => assert(offset == s.length) case Some(line) => assert(line.length >= col) - if (line.length == col) assert(offset == s.length || s(offset) == '\n') + if (line.length == col) + assert(offset == s.length || s(offset) == '\n') else assert(line(col) == s(offset)) } } @@ -67,7 +76,7 @@ class LocationMapTest extends AnyFunSuite { forAll { (s: String) => LocationMap(s).toLineCol(0) match { case Some(r) => assert(r == ((0, 0))) - case None => assert(s.isEmpty) + case None => assert(s.isEmpty) } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala index 1a4258401..a902d71c6 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MatchlessTests.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration} +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Identifier.{Bindable, Constructor} import rankn.DataRepr @@ -12,26 +15,28 @@ import org.scalatest.funsuite.AnyFunSuite class MatchlessTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 20 + ) type Fn = (PackageName, Constructor) => Option[DataRepr] - def fnFromTypeEnv[A](te: rankn.TypeEnv[A]): Fn = - { - // the list constructors *have* to be in scope or matching will generate - // bad code - case (PackageName.PredefName, Constructor("EmptyList")) => - Some(DataRepr.Enum(0, 0, List(0, 1))) - case (PackageName.PredefName, Constructor("NonEmptyList")) => - Some(DataRepr.Enum(1, 2, List(0, 1))) - case (pn, cons) => - te.getConstructor(pn, cons) - .map(_._1.dataRepr(cons)) - .orElse(Some(DataRepr.Struct(0))) - } + def fnFromTypeEnv[A](te: rankn.TypeEnv[A]): Fn = { + // the list constructors *have* to be in scope or matching will generate + // bad code + case (PackageName.PredefName, Constructor("EmptyList")) => + Some(DataRepr.Enum(0, 0, List(0, 1))) + case (PackageName.PredefName, Constructor("NonEmptyList")) => + Some(DataRepr.Enum(1, 2, List(0, 1))) + case (pn, cons) => + te.getConstructor(pn, cons) + .map(_._1.dataRepr(cons)) + .orElse(Some(DataRepr.Struct(0))) + } lazy val genInputs: Gen[(Bindable, RecursionKind, TypedExpr[Unit], Fn)] = - Generators.genPackage(Gen.const(()), 5) + Generators + .genPackage(Gen.const(()), 5) .flatMap { (m: Map[PackageName, Package.Typed[Unit]]) => val candidates = m.filter { case (_, t) => t.program.lets.nonEmpty } @@ -59,7 +64,9 @@ class MatchlessTest extends AnyFunSuite { val name = Identifier.Name("foo") val te = TypedExpr.Local(name, rankn.Type.IntType, ()) // this should not throw - val me = Matchless.fromLet(name, RecursionKind.Recursive, te)(fnFromTypeEnv(rankn.TypeEnv.empty)) + val me = Matchless.fromLet(name, RecursionKind.Recursive, te)( + fnFromTypeEnv(rankn.TypeEnv.empty) + ) assert(me != null) } @@ -83,14 +90,16 @@ class MatchlessTest extends AnyFunSuite { } test("Matchless.stopAt works") { - forAll(genNE(100, Gen.choose(-100, 100)), Arbitrary.arbitrary[Int => Boolean]) { (nel, fn) => + forAll( + genNE(100, Gen.choose(-100, 100)), + Arbitrary.arbitrary[Int => Boolean] + ) { (nel, fn) => val stopped = Matchless.stopAt(nel)(fn) if (fn(stopped.last)) { // none of the items before the last are true: assert(stopped.init.exists(fn) == false) - } - else { + } else { // none of them were true assert(stopped == nel) assert(nel.exists(fn) == false) @@ -105,31 +114,39 @@ class MatchlessTest extends AnyFunSuite { for { s <- size left <- Gen.listOfN(s, bytes) - sright <- Gen.choose(0, 2*s) - pat <- Gen.listOfN(sright, Arbitrary.arbitrary[Option[Byte => Option[Int]]]) + sright <- Gen.choose(0, 2 * s) + pat <- Gen.listOfN( + sright, + Arbitrary.arbitrary[Option[Byte => Option[Int]]] + ) } yield (left, pat) } import pattern.{SeqPattern, SeqPart, Splitter, Matcher} - def toSeqPat[A, B](pat: List[Option[A => Option[B]]]): SeqPattern[A => Option[B]] = + def toSeqPat[A, B]( + pat: List[Option[A => Option[B]]] + ): SeqPattern[A => Option[B]] = SeqPattern.fromList(pat.map { - case None => SeqPart.Wildcard - case Some(fn) =>SeqPart.Lit(fn) + case None => SeqPart.Wildcard + case Some(fn) => SeqPart.Lit(fn) }) val matcher = SeqPattern.matcher( Splitter.listSplitter(new Matcher[Byte => Option[Int], Byte, Int] { def apply(fn: Byte => Option[Int]) = fn - })) + }) + ) forAll(genArgs) { case (targ, pat) => val seqPat = toSeqPat(pat) val matchRes = matcher(seqPat)(targ) - val matchlessRes = Matchless.matchList(targ, + val matchlessRes = Matchless.matchList( + targ, pat.map { - case None => Left { (_: List[Byte]) => 0 } + case None => Left { (_: List[Byte]) => 0 } case Some(fn) => Right(fn) - }) + } + ) assert(matchlessRes == matchRes) } diff --git a/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala b/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala index c655acd3d..812e27e58 100644 --- a/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/MonadGen.scala @@ -14,4 +14,4 @@ object MonadGen { def tailRecM[A, B](a: A)(fn: A => Gen[Either[A, B]]): Gen[B] = Gen.tailRecM(a)(fn) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala b/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala index 817695225..34f6aa2b2 100644 --- a/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/OperatorTest.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.typelevel.paiges.{ Doc, Document } +import org.typelevel.paiges.{Doc, Document} import cats.parse.{Parser => P} @@ -14,7 +14,7 @@ class OperatorTest extends ParserTestBase { sealed abstract class F { def toFormula: Formula[String] = this match { - case F.Num(s) => Formula.Sym(s) + case F.Num(s) => Formula.Sym(s) case F.Form(Formula.Sym(n)) => n.toFormula case F.Form(Formula.Op(left, op, right)) => Formula.Op(F.Form(left).toFormula, op, F.Form(right).toFormula) @@ -28,8 +28,7 @@ class OperatorTest extends ParserTestBase { lazy val formP: P[F] = Operators.Formula .parser( - Parser - .integerString + Parser.integerString .map(F.Num(_)) .orElse(P.defer(formP.parensCut)) ) @@ -43,7 +42,11 @@ class OperatorTest extends ParserTestBase { } def parseSame(left: String, right: String) = - assert(Parser.unsafeParse(formP, left).toFormula == Parser.unsafeParse(formP, right).toFormula) + assert( + Parser.unsafeParse(formP, left).toFormula == Parser + .unsafeParse(formP, right) + .toFormula + ) test("we can parse integer formulas") { parseSame("1+2", "1 + 2") @@ -69,7 +72,8 @@ class OperatorTest extends ParserTestBase { } test("test operator precedence in real programs") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Test operator + = add @@ -83,9 +87,13 @@ test = TestSuite("precedence", Assertion(1 * 2 * 3 == (1 * 2) * 3, "p1"), Assertion(1 + 2 % 3 == 1 + (2 % 3), "p1") ]) -"""), "Test", 3) +"""), + "Test", + 3 + ) - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Test # this is non-associative so we can test order @@ -102,9 +110,14 @@ test = TestSuite("precedence", [ Assertion(1 *> 2 *> 3 == (1 *> 2) *> 3, "p1"), ]) -"""), "Test", 1) - - runBosatsuTest(List(""" +"""), + "Test", + 1 + ) + + runBosatsuTest( + List( + """ package T1 export operator +, operator *, operator == @@ -113,7 +126,7 @@ operator + = add operator * = times operator == = eq_Int """, - """ + """ package T2 from T1 import operator + as operator ++, `*`, `==` @@ -124,11 +137,16 @@ from T1 import operator + as operator ++, `*`, `==` test = TestSuite("import export", [ Assertion(1 +. (2 * 3) == 1 .+ (2 * 3), "p1"), Assertion(1 .+ 2 * 3 == (1 .+ 2) * 3, "p1") ]) -"""), "T2", 2) +""" + ), + "T2", + 2 + ) } test("test ternary operator precedence") { - runBosatsuTest(List(""" + runBosatsuTest( + List(""" package Test operator == = eq_Int @@ -147,6 +165,9 @@ test = TestSuite("precedence", Assertion(left1 == right1, "p1"), Assertion(left2 == right2, "p2"), ]) -"""), "Test", 2) +"""), + "Test", + 2 + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala b/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala index 32311fd78..06684366e 100644 --- a/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/PackageTest.scala @@ -8,9 +8,12 @@ import org.scalatest.funsuite.AnyFunSuite class PackageTest extends AnyFunSuite with ParTest { - def resolveThenInfer(ps: Iterable[Package.Parsed]): ValidatedNel[PackageError, PackageMap.Inferred] = { + def resolveThenInfer( + ps: Iterable[Package.Parsed] + ): ValidatedNel[PackageError, PackageMap.Inferred] = { implicit val showInt: Show[Int] = Show.fromToString - PackageMap.resolveThenInfer(ps.toList.zipWithIndex.map(_.swap), Nil) + PackageMap + .resolveThenInfer(ps.toList.zipWithIndex.map(_.swap), Nil) .strictToValidated } @@ -22,7 +25,7 @@ class PackageTest extends AnyFunSuite with ParTest { def valid[A, B](v: Validated[A, B]) = v match { - case Validated.Valid(_) => succeed + case Validated.Valid(_) => succeed case Validated.Invalid(err) => fail(err.toString) } @@ -33,15 +36,13 @@ class PackageTest extends AnyFunSuite with ParTest { } test("simple package resolves") { - val p1 = parse( -""" + val p1 = parse(""" package Foo export main main = 1 """) - val p2 = parse( -""" + val p2 = parse(""" package Foo2 from Foo import main as mainFoo export main, @@ -49,8 +50,7 @@ export main, main = mainFoo """) - val p3 = parse( -""" + val p3 = parse(""" package Foo from Foo2 import main as mainFoo @@ -61,8 +61,7 @@ main = 1 valid(resolveThenInfer(List(p1, p2))) invalid(resolveThenInfer(List(p2, p3))) // loop here - val p4 = parse( -""" + val p4 = parse(""" package P4 from Foo2 import main as one @@ -72,8 +71,7 @@ main = add(one, 42) """) valid(resolveThenInfer(List(p1, p2, p4))) - val p5 = parse( -""" + val p5 = parse(""" package P5 export Option(), List(), head, tail @@ -99,8 +97,7 @@ def tail(list): case NonEmpty(_, t): Some(t) """) - val p6 = parse( -""" + val p6 = parse(""" package P6 from P5 import Option, List, NonEmpty, Empty, head, tail export data @@ -111,8 +108,7 @@ main = head(data) """) valid(resolveThenInfer(List(p5, p6))) - val p7 = parse( -""" + val p7 = parse(""" package P7 from P6 import data as p6_data from P5 import Option, List, NonEmpty as Cons, Empty as Nil, head, tail @@ -129,8 +125,7 @@ main = head(data1) assert(Package.predefPackage != null) - val p = parse( -""" + val p = parse(""" package UsePredef def maybeOne(x): @@ -146,8 +141,7 @@ main = maybeOne(42) test("test using a renamed type") { - val p1 = parse( -""" + val p1 = parse(""" package R1 export Foo(), mkFoo, takeFoo @@ -161,8 +155,7 @@ def takeFoo(foo): 0 """) - val p2 = parse( -""" + val p2 = parse(""" package R2 from R1 import Foo as Bar, mkFoo, takeFoo diff --git a/core/src/test/scala/org/bykn/bosatsu/ParTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParTest.scala index e1d284513..49977f33a 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParTest.scala @@ -1,6 +1,6 @@ package org.bykn.bosatsu -import org.scalatest.{BeforeAndAfterAll, Suite } +import org.scalatest.{BeforeAndAfterAll, Suite} trait ParTest extends BeforeAndAfterAll { self: Suite => diff --git a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala index af4872339..42a6f0b14 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ParserTest.scala @@ -3,7 +3,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import Parser.Combinators import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.typelevel.paiges.{Doc, Document} import cats.implicits._ @@ -19,8 +22,7 @@ trait ParseFns { else if (s0.length == idx) { val s = s0 + "*" ("...(" + s.drop(idx - 20).take(20) + ")...") - } - else { + } else { val s = s0.updated(idx, '*') ("...(" + s.drop(idx - 20).take(30) + ")...") } @@ -30,7 +32,8 @@ trait ParseFns { else if (s1.isEmpty) s2 else if (s2.isEmpty) s1 else if (s1(0) == s2(0)) firstDiff(s1.tail, s2.tail) - else s"${s1(0).toInt}: ${s1.take(20)}... != ${s2(0).toInt}: ${s2.take(20)}..." + else + s"${s1(0).toInt}: ${s1.take(20)}... != ${s2(0).toInt}: ${s2.take(20)}..." } @@ -46,11 +49,16 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { case Right((rest, t)) => val idx = if (rest == "") str.length else str.indexOf(rest) lazy val message = firstDiff(t.toString, expected.toString) - assert(t == expected, s"difference: $message, input syntax:\n\n\n$str\n\n") + assert( + t == expected, + s"difference: $message, input syntax:\n\n\n$str\n\n" + ) assert(idx == exidx) case Left(err) => val idx = err.failedAtOffset - fail(s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}") + fail( + s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}" + ) } def parseTestAll[T](p: P0[T], str: String, expected: T) = @@ -70,11 +78,15 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { case Left(err) => val idx = err.failedAtOffset val diff = firstDiff(str, tstr) - fail(s"Diff: $diff.\nfailed to reparse: $tstr: $idx in region ${region(tstr, idx)} with err: ${err}") + fail( + s"Diff: $diff.\nfailed to reparse: $tstr: $idx in region ${region(tstr, idx)} with err: ${err}" + ) } case Left(err) => val idx = err.failedAtOffset - fail(s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}") + fail( + s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}" + ) } def roundTripExact[T: Document](p: P0[T], str: String) = @@ -86,7 +98,9 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { assert(tstr == str) case Left(err) => val idx = err.failedAtOffset - fail(s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}") + fail( + s"failed to parse: $str: $idx in region ${region(str, idx)} with err: ${err}" + ) } def law[T: Document](p: P0[T])(t: T) = { @@ -101,12 +115,15 @@ abstract class ParserTestBase extends AnyFunSuite with ParseFns { fail(s"parsed $t to: $idx: ${region(str, idx)}") case Left(err) => val idx = err.failedAtOffset - def msg = s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}" + def msg = + s"failed to parse: $str: at $idx in region ${region(str, idx)} with err: ${err}" assert(idx == atIdx, msg) } def config: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 300 else 10) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 300 else 10 + ) } class ParserTest extends ParserTestBase { @@ -128,10 +145,11 @@ class ParserTest extends ParserTestBase { def loop(b: String): Gen[String] = if (b.length <= 1) Gen.const(b) - else for { - s <- sep - tail <- loop(b.tail) - } yield s"${b.charAt(0)}$s$tail" + else + for { + s <- sep + tail <- loop(b.tail) + } yield s"${b.charAt(0)}$s$tail" loop(bstr).map(Opaque(_)) } @@ -148,11 +166,12 @@ class ParserTest extends ParserTestBase { try { Parser.unescape(str1) match { case Right(str2) => assert(str2 == str) - case Left(idx) => fail(s"failed at idx: $idx in $str: ${region(str, idx)}") + case Left(idx) => + fail(s"failed at idx: $idx in $str: ${region(str, idx)}") } - } - catch { - case t: Throwable => fail(s"failed to decode: $str1 from $str, exception: $t") + } catch { + case t: Throwable => + fail(s"failed to decode: $str1 from $str, exception: $t") } } @@ -205,7 +224,6 @@ class ParserTest extends ParserTestBase { val regressions = List(("'", '\'')) - regressions.foreach { case (s, c) => law(s, c) } } @@ -216,9 +234,11 @@ class ParserTest extends ParserTestBase { .interpolatedString('\'', P.string("${"), Json.parser, P.char('}')) .map(_.map { case Right((_, str)) => Right(str) - case Left(l) => Left(l) - }) - , str1, res) + case Left(l) => Left(l) + }), + str1, + res + ) // scala complains about things that look like interpolation strings that aren't interpolated val dollar = '$'.toString @@ -228,16 +248,22 @@ class ParserTest extends ParserTestBase { singleq(s"'foo\\$dollar{bar}'", List(Right(s"foo$dollar{bar}"))) // foo$bar is okay, it is only foo${bar} that needs to be escaped singleq(s"'foo${dollar}bar'", List(Right(s"foo${dollar}bar"))) - singleq(s"'foo$dollar{42}'", List(Right("foo"), Left(Json.JNumberStr("42")))) + singleq( + s"'foo$dollar{42}'", + List(Right("foo"), Left(Json.JNumberStr("42"))) + ) singleq(s"'$dollar{42}'", List(Left(Json.JNumberStr("42")))) - singleq(s"'$dollar{42}bar'", List(Left(Json.JNumberStr("42")), Right("bar"))) + singleq( + s"'$dollar{42}bar'", + List(Left(Json.JNumberStr("42")), Right("bar")) + ) } test("Identifier round trips") { forAll(Generators.identifierGen)(law(Identifier.parser)) - val examples = List("foo", "`bar`", "`bar foo`", - "`with \\`internal`", "operator +") + val examples = + List("foo", "`bar`", "`bar foo`", "`with \\`internal`", "operator +") examples.foreach(roundTrip(Identifier.parser, _)) } @@ -257,38 +283,55 @@ class ParserTest extends ParserTestBase { val str0 = ls.toString val str = str0.flatMap { case ',' => "," + (" " * spaceCount) - case c => c.toString + case c => c.toString } val listOfStr: P[List[String]] = P.string("List(") *> - Parser.integerString.nonEmptyList.map(_.toList) + Parser.integerString.nonEmptyList + .map(_.toList) .orElse(P.pure(Nil)) <* - P.char(')') + P.char(')') - parseTestAll( - listOfStr, - str, - ls.map(_.toString)) + parseTestAll(listOfStr, str, ls.map(_.toString)) } } test("we can parse dicts") { - val strDict = Parser.dictLikeParser(Parser.escapedString('\''), Parser.escapedString('\'')) + val strDict = Parser.dictLikeParser( + Parser.escapedString('\''), + Parser.escapedString('\'') + ) parseTestAll(strDict, "{}", Nil) parseTestAll(strDict, "{'a': 'b'}", List(("a", "b"))) parseTestAll(strDict, "{ 'a' : 'b' }", List(("a", "b"))) parseTestAll(strDict, "{'a' : 'b', 'c': 'd'}", List(("a", "b"), ("c", "d"))) - parseTestAll(strDict, "{'a' : 'b',\n'c': 'd'}", List(("a", "b"), ("c", "d"))) - parseTestAll(strDict, "{'a' : 'b',\n\t'c': 'd'}", List(("a", "b"), ("c", "d"))) - parseTestAll(strDict, "{'a' : 'b',\n 'c': 'd'}", List(("a", "b"), ("c", "d"))) - - case class WildDict(stringRepNoCurlies: List[String], original: List[(String, String)]) { + parseTestAll( + strDict, + "{'a' : 'b',\n'c': 'd'}", + List(("a", "b"), ("c", "d")) + ) + parseTestAll( + strDict, + "{'a' : 'b',\n\t'c': 'd'}", + List(("a", "b"), ("c", "d")) + ) + parseTestAll( + strDict, + "{'a' : 'b',\n 'c': 'd'}", + List(("a", "b"), ("c", "d")) + ) + + case class WildDict( + stringRepNoCurlies: List[String], + original: List[(String, String)] + ) { def stringRep: String = stringRepNoCurlies.mkString("{", "", "}") def addEntry(strings: List[String], k: String, v: String): WildDict = if (stringRepNoCurlies.isEmpty) WildDict(strings, (k, v) :: original) - else WildDict(strings ::: ("," :: stringRepNoCurlies), (k, v) :: original) + else + WildDict(strings ::: ("," :: stringRepNoCurlies), (k, v) :: original) } val genString = Arbitrary.arbitrary[String] @@ -319,7 +362,14 @@ class ParserTest extends ParserTestBase { test("we can parse RecordConstructors") { def check(str: String) = - roundTrip[Declaration](Declaration.recordConstructorP("", Declaration.varP, Declaration.varP.orElse(Declaration.lits)), str) + roundTrip[Declaration]( + Declaration.recordConstructorP( + "", + Declaration.varP, + Declaration.varP.orElse(Declaration.lits) + ), + str + ) check("Foo { bar }") check("Foo{bar}") @@ -355,7 +405,7 @@ class ParserTest extends ParserTestBase { check("Foo{x:1}") // from scalacheck - //check("Ze8lujlrbo {wlqOvp: {}}") + // check("Ze8lujlrbo {wlqOvp: {}}") } test("we can parse tuples") { @@ -368,9 +418,11 @@ class ParserTest extends ParserTestBase { case _ => ls.mkString("(", "," + pad, ")") } - parseTestAll(Parser.integerString.tupleOrParens, + parseTestAll( + Parser.integerString.tupleOrParens, str, - Right(ls.map(_.toString))) + Right(ls.map(_.toString)) + ) } // a single item is parsed as parens @@ -378,45 +430,71 @@ class ParserTest extends ParserTestBase { val spaceCount = spaceCnt0 & 7 val pad = " " * spaceCount val str = s"($it$pad)" - parseTestAll(Parser.integerString.tupleOrParens, - str, - Left(it.toString)) + parseTestAll(Parser.integerString.tupleOrParens, str, Left(it.toString)) } } test("we can parse blocks") { - val indy = OptIndent.block(Indy.lift(P.string("if foo")), Indy.lift(P.string("bar"))) + val indy = + OptIndent.block(Indy.lift(P.string("if foo")), Indy.lift(P.string("bar"))) val p = indy.run("") parseTestAll(p, "if foo: bar", ((), OptIndent.same(()))) parseTestAll(p, "if foo:\n\tbar", ((), OptIndent.paddedIndented(1, 4, ()))) - parseTestAll(p, "if foo:\n bar", ((), OptIndent.paddedIndented(1, 4, ()))) + parseTestAll( + p, + "if foo:\n bar", + ((), OptIndent.paddedIndented(1, 4, ())) + ) parseTestAll(p, "if foo:\n bar", ((), OptIndent.paddedIndented(1, 2, ()))) import Indy.IndyMethods val repeated = indy.nonEmptyList(Indy.lift(Parser.toEOL1)) val single = ((), OptIndent.notSame(Padding(1, Indented(2, ())))) - parseTestAll(repeated.run(""), "if foo:\n bar\nif foo:\n bar", - NonEmptyList.of(single, single)) + parseTestAll( + repeated.run(""), + "if foo:\n bar\nif foo:\n bar", + NonEmptyList.of(single, single) + ) // we can nest blocks - parseTestAll(OptIndent.block(Indy.lift(P.string("nest")), indy)(""), "nest: if foo: bar", - ((), OptIndent.same(((), OptIndent.same(()))))) - parseTestAll(OptIndent.block(Indy.lift(P.string("nest")), indy)(""), "nest:\n if foo: bar", - ((), OptIndent.paddedIndented(1, 2, ((), OptIndent.same(()))))) - parseTestAll(OptIndent.block(Indy.lift(P.string("nest")), indy)(""), "nest:\n if foo:\n bar", - ((), OptIndent.paddedIndented(1, 2, ((), OptIndent.paddedIndented(1, 2, ()))))) - - val simpleBlock = OptIndent.block(Indy.lift(Parser.lowerIdent <* Parser.maybeSpace), Indy.lift(Parser.lowerIdent)) + parseTestAll( + OptIndent.block(Indy.lift(P.string("nest")), indy)(""), + "nest: if foo: bar", + ((), OptIndent.same(((), OptIndent.same(())))) + ) + parseTestAll( + OptIndent.block(Indy.lift(P.string("nest")), indy)(""), + "nest:\n if foo: bar", + ((), OptIndent.paddedIndented(1, 2, ((), OptIndent.same(())))) + ) + parseTestAll( + OptIndent.block(Indy.lift(P.string("nest")), indy)(""), + "nest:\n if foo:\n bar", + ( + (), + OptIndent.paddedIndented(1, 2, ((), OptIndent.paddedIndented(1, 2, ()))) + ) + ) + + val simpleBlock = OptIndent + .block( + Indy.lift(Parser.lowerIdent <* Parser.maybeSpace), + Indy.lift(Parser.lowerIdent) + ) .nonEmptyList(Indy.toEOLIndent) - val sbRes = NonEmptyList.of(("x1", OptIndent.paddedIndented(1, 2, "x2")), - ("y1", OptIndent.paddedIndented(1, 3, "y2"))) + val sbRes = NonEmptyList.of( + ("x1", OptIndent.paddedIndented(1, 2, "x2")), + ("y1", OptIndent.paddedIndented(1, 3, "y2")) + ) parseTestAll(simpleBlock(""), "x1:\n x2\ny1:\n y2", sbRes) - parseTestAll(OptIndent.block(Indy.lift(Parser.lowerIdent), simpleBlock)(""), + parseTestAll( + OptIndent.block(Indy.lift(Parser.lowerIdent), simpleBlock)(""), "block:\n x1:\n x2\n y1:\n y2", - ("block", OptIndent.paddedIndented(1, 2, sbRes))) + ("block", OptIndent.paddedIndented(1, 2, sbRes)) + ) } def trName(s: String): TypeRef.TypeName = @@ -426,23 +504,62 @@ class ParserTest extends ParserTestBase { parseTestAll(TypeRef.parser, "foo", TypeRef.TypeVar("foo")) parseTestAll(TypeRef.parser, "Foo", trName("Foo")) - parseTestAll(TypeRef.parser, "forall a. a", - TypeRef.TypeForAll(NonEmptyList.of((TypeRef.TypeVar("a"), None)), TypeRef.TypeVar("a"))) - parseTestAll(TypeRef.parser, "forall a, b. f[a] -> f[b]", - TypeRef.TypeForAll(NonEmptyList.of((TypeRef.TypeVar("a"), None), (TypeRef.TypeVar("b"), None)), + parseTestAll( + TypeRef.parser, + "forall a. a", + TypeRef.TypeForAll( + NonEmptyList.of((TypeRef.TypeVar("a"), None)), + TypeRef.TypeVar("a") + ) + ) + parseTestAll( + TypeRef.parser, + "forall a, b. f[a] -> f[b]", + TypeRef.TypeForAll( + NonEmptyList + .of((TypeRef.TypeVar("a"), None), (TypeRef.TypeVar("b"), None)), TypeRef.TypeArrow( - TypeRef.TypeApply(TypeRef.TypeVar("f"), NonEmptyList.of(TypeRef.TypeVar("a"))), - TypeRef.TypeApply(TypeRef.TypeVar("f"), NonEmptyList.of(TypeRef.TypeVar("b")))))) + TypeRef.TypeApply( + TypeRef.TypeVar("f"), + NonEmptyList.of(TypeRef.TypeVar("a")) + ), + TypeRef.TypeApply( + TypeRef.TypeVar("f"), + NonEmptyList.of(TypeRef.TypeVar("b")) + ) + ) + ) + ) roundTrip(TypeRef.parser, "forall a, b. f[a] -> f[b]") roundTrip(TypeRef.parser, "(forall a, b. f[a]) -> f[b]") roundTrip(TypeRef.parser, "(forall a, b. f[a])[Int]") // apply a type - parseTestAll(TypeRef.parser, "Foo -> Bar", TypeRef.TypeArrow(trName("Foo"), trName("Bar"))) - parseTestAll(TypeRef.parser, "Foo -> Bar -> baz", - TypeRef.TypeArrow(trName("Foo"), TypeRef.TypeArrow(trName("Bar"), TypeRef.TypeVar("baz")))) - parseTestAll(TypeRef.parser, "(Foo -> Bar) -> baz", - TypeRef.TypeArrow(TypeRef.TypeArrow(trName("Foo"), trName("Bar")), TypeRef.TypeVar("baz"))) - parseTestAll(TypeRef.parser, "Foo[Bar]", TypeRef.TypeApply(trName("Foo"), NonEmptyList.of(trName("Bar")))) + parseTestAll( + TypeRef.parser, + "Foo -> Bar", + TypeRef.TypeArrow(trName("Foo"), trName("Bar")) + ) + parseTestAll( + TypeRef.parser, + "Foo -> Bar -> baz", + TypeRef.TypeArrow( + trName("Foo"), + TypeRef.TypeArrow(trName("Bar"), TypeRef.TypeVar("baz")) + ) + ) + parseTestAll( + TypeRef.parser, + "(Foo -> Bar) -> baz", + TypeRef.TypeArrow( + TypeRef.TypeArrow(trName("Foo"), trName("Bar")), + TypeRef.TypeVar("baz") + ) + ) + parseTestAll( + TypeRef.parser, + "Foo[Bar]", + TypeRef.TypeApply(trName("Foo"), NonEmptyList.of(trName("Bar"))) + ) forAll(Generators.typeRefGen) { tref => parseTestAll(TypeRef.parser, tref.toDoc.render(80), tref) @@ -457,19 +574,37 @@ class ParserTest extends ParserTestBase { val varA = TyVar(Var.Bound("a")) val varB = TyVar(Var.Bound("b")) - val FooBarBar = TyConst(Const.Defined(PackageName.parts("Foo", "Bar"), TypeName(Identifier.Constructor("Bar")))) + val FooBarBar = TyConst( + Const.Defined( + PackageName.parts("Foo", "Bar"), + TypeName(Identifier.Constructor("Bar")) + ) + ) check("a", varA) check("Foo/Bar::Bar", FooBarBar) check("a -> Foo/Bar::Bar", Fun(varA, FooBarBar)) - check("forall a, b. Foo/Bar::Bar[a, b]", Type.forAll(List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), TyApply(TyApply(FooBarBar, varA), varB))) - check("forall a. forall b. Foo/Bar::Bar[a, b]", Type.forAll(List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), TyApply(TyApply(FooBarBar, varA), varB))) + check( + "forall a, b. Foo/Bar::Bar[a, b]", + Type.forAll( + List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), + TyApply(TyApply(FooBarBar, varA), varB) + ) + ) + check( + "forall a. forall b. Foo/Bar::Bar[a, b]", + Type.forAll( + List((Var.Bound("a"), Kind.Type), (Var.Bound("b"), Kind.Type)), + TyApply(TyApply(FooBarBar, varA), varB) + ) + ) check("(a)", varA) check("(a, b)", Tuple(List(varA, varB))) } test("we can parse python style list expressions") { val pident = Parser.lowerIdent - implicit val stringDoc: Document[String] = Document.instance[String](Doc.text(_)) + implicit val stringDoc: Document[String] = + Document.instance[String](Doc.text(_)) val llp = ListLang.parser(pident, pident, pident) roundTrip(llp, "[a]") @@ -489,9 +624,22 @@ class ParserTest extends ParserTestBase { test("we can parse operators") { val singleToks = List( - "+", "-", "*", "!", "$", "%", - "^", "&", "*", "|", "?", "/", "<", - ">", "~") + "+", + "-", + "*", + "!", + "$", + "%", + "^", + "&", + "*", + "|", + "?", + "/", + "<", + ">", + "~" + ) val withEq = "=" :: singleToks val allLen2 = (withEq, withEq).mapN(_ + _) @@ -515,9 +663,8 @@ class ParserTest extends ParserTestBase { } } -/** - * This is a separate class since some of these are very slow - */ +/** This is a separate class since some of these are very slow + */ class SyntaxParseTest extends ParserTestBase { implicit val generatorDrivenConfig: PropertyCheckConfiguration = config @@ -526,11 +673,18 @@ class SyntaxParseTest extends ParserTestBase { Declaration.Var(Identifier.Name(n)) test("we can parse comments") { - val gen = Generators.commentGen(Generators.padding(Generators.genDeclaration(0), 1)) + val gen = + Generators.commentGen(Generators.padding(Generators.genDeclaration(0), 1)) forAll(gen) { comment => - parseTestAll(CommentStatement.parser(i => Padding.parser(Declaration.parser(i))).run(""), - Document[CommentStatement[Padding[Declaration]]].document(comment).render(80), - comment) + parseTestAll( + CommentStatement + .parser(i => Padding.parser(Declaration.parser(i))) + .run(""), + Document[CommentStatement[Padding[Declaration]]] + .document(comment) + .render(80), + comment + ) } val commentLit = """#foo @@ -541,8 +695,12 @@ class SyntaxParseTest extends ParserTestBase { Declaration.parser(""), commentLit, Declaration.CommentNB( - CommentStatement(NonEmptyList.of("foo", "bar"), - Padding(1, Declaration.Literal(Lit.fromInt(1)))))) + CommentStatement( + NonEmptyList.of("foo", "bar"), + Padding(1, Declaration.Literal(Lit.fromInt(1))) + ) + ) + ) val parensComment = """(#foo #bar @@ -551,9 +709,15 @@ class SyntaxParseTest extends ParserTestBase { parseTestAll( Declaration.parser(""), parensComment, - Declaration.Parens(Declaration.CommentNB( - CommentStatement(NonEmptyList.of("foo", "bar"), - Padding(1, Declaration.Literal(Lit.fromInt(1))))))) + Declaration.Parens( + Declaration.CommentNB( + CommentStatement( + NonEmptyList.of("foo", "bar"), + Padding(1, Declaration.Literal(Lit.fromInt(1))) + ) + ) + ) + ) } test("we can parse Lit.Integer") { @@ -563,11 +727,19 @@ class SyntaxParseTest extends ParserTestBase { } test("we can parse DefStatement") { - forAll(Generators.defGen(Generators.optIndent(Generators.genDeclaration(0)))) { defn => + forAll( + Generators.defGen(Generators.optIndent(Generators.genDeclaration(0))) + ) { defn => parseTestAll[DefStatement[Pattern.Parsed, OptIndent[Declaration]]]( - DefStatement.parser(Pattern.bindParser, Parser.maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("")), - Document[DefStatement[Pattern.Parsed, OptIndent[Declaration]]].document(defn).render(80), - defn) + DefStatement.parser( + Pattern.bindParser, + Parser.maybeSpace.with1 *> OptIndent.indy(Declaration.parser).run("") + ), + Document[DefStatement[Pattern.Parsed, OptIndent[Declaration]]] + .document(defn) + .render(80), + defn + ) } val defWithComment = """def foo(a): @@ -577,81 +749,160 @@ foo""" parseTestAll( Declaration.parser(""), defWithComment, - Declaration.DefFn(DefStatement(Identifier.Name("foo"), None, - NonEmptyList.one(NonEmptyList.one(Pattern.Var(Identifier.Name("a")))), None, - (OptIndent.paddedIndented(1, 2, Declaration.CommentNB(CommentStatement(NonEmptyList.of(" comment here"), - Padding(0, mkVar("a"))))), - Padding(0, mkVar("foo")))))) + Declaration.DefFn( + DefStatement( + Identifier.Name("foo"), + None, + NonEmptyList.one(NonEmptyList.one(Pattern.Var(Identifier.Name("a")))), + None, + ( + OptIndent.paddedIndented( + 1, + 2, + Declaration.CommentNB( + CommentStatement( + NonEmptyList.of(" comment here"), + Padding(0, mkVar("a")) + ) + ) + ), + Padding(0, mkVar("foo")) + ) + ) + ) + ) roundTrip(Declaration.parser(""), defWithComment) // Here is a pretty brutal randomly generated case - roundTrip(Declaration.parser(""), -"""def uwr(dw: h6lmZhgg) -> forall lnNR. Z5syis -> Mhgm: + roundTrip( + Declaration.parser(""), + """def uwr(dw: h6lmZhgg) -> forall lnNR. Z5syis -> Mhgm: -349743008 -foo""") +foo""" + ) } test("we can parse BindingStatement") { val dp = Declaration.parser("") - parseTestAll(dp, + parseTestAll( + dp, """foo = 5 5""", - Declaration.Binding(BindingStatement(Pattern.Var(Identifier.Name("foo")), Declaration.Literal(Lit.fromInt(5)), - Padding(1, Declaration.Literal(Lit.fromInt(5)))))) - + Declaration.Binding( + BindingStatement( + Pattern.Var(Identifier.Name("foo")), + Declaration.Literal(Lit.fromInt(5)), + Padding(1, Declaration.Literal(Lit.fromInt(5))) + ) + ) + ) - roundTrip(dp, -"""# + roundTrip( + dp, + """# Pair(_, x) = z -x""") +x""" + ) } test("we can parse any Apply") { import Declaration._ - import ApplyKind.{Dot => ADot, Parens => AParens } + import ApplyKind.{Dot => ADot, Parens => AParens} - parseTestAll(parser(""), + parseTestAll( + parser(""), "x(f)", - Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), AParens)) + Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), AParens) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "f.x()", - Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), ADot)) + Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), ADot) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "f(foo).x()", - Apply(mkVar("x"), NonEmptyList.of(Apply(mkVar("f"), NonEmptyList.of(mkVar("foo")), AParens)), ADot)) + Apply( + mkVar("x"), + NonEmptyList.of( + Apply(mkVar("f"), NonEmptyList.of(mkVar("foo")), AParens) + ), + ADot + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "f.foo(x)", // foo(f, x) - Apply(mkVar("foo"), NonEmptyList.of(mkVar("f"), mkVar("x")), ADot)) + Apply(mkVar("foo"), NonEmptyList.of(mkVar("f"), mkVar("x")), ADot) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "(\\x -> x)(f)", - Apply(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x"))), NonEmptyList.of(mkVar("f")), AParens)) + Apply( + Parens( + Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x")) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), "((\\x -> x)(f))", - Parens(Apply(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x"))), NonEmptyList.of(mkVar("f")), AParens))) + Parens( + Apply( + Parens( + Lambda( + NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), + mkVar("x") + ) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + ) + ) // bare lambda - parseTestAll(parser(""), + parseTestAll( + parser(""), "((x -> x)(f))", - Parens(Apply(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x"))), NonEmptyList.of(mkVar("f")), AParens))) + Parens( + Apply( + Parens( + Lambda( + NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), + mkVar("x") + ) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + ) + ) - val expected = Apply(Parens(Parens(Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x")))), NonEmptyList.of(mkVar("f")), AParens) - parseTestAll(parser(""), - "((\\x -> x))(f)", - expected) + val expected = Apply( + Parens( + Parens( + Lambda(NonEmptyList.of(Pattern.Var(Identifier.Name("x"))), mkVar("x")) + ) + ), + NonEmptyList.of(mkVar("f")), + AParens + ) + parseTestAll(parser(""), "((\\x -> x))(f)", expected) - parseTestAll(parser(""), - expected.toDoc.render(80), - expected) + parseTestAll(parser(""), expected.toDoc.render(80), expected) } @@ -693,7 +944,7 @@ x""") test("Declaration.toPattern works for all Pattern-like declarations") { def law1(dec: Declaration.NonBinding) = { Declaration.toPattern(dec) match { - case None => fail("expected to convert to pattern") + case None => fail("expected to convert to pattern") case Some(pat) => // if we convert to string this parses the same as a pattern: val decStr = dec.toDoc.render(80) @@ -708,8 +959,18 @@ x""") import Identifier.{Name, Operator, Constructor} // this operator application can be a pattern List( - ApplyOp(Var(Name("q")),Operator("|"),Var(Name("npzma"))), - ApplyOp(Parens(ApplyOp(Parens(Literal(Lit.Str("igyimc"))),Operator("|"),Var(Name("ncf5Eo9")))),Operator("|"),Var(Constructor("K"))) + ApplyOp(Var(Name("q")), Operator("|"), Var(Name("npzma"))), + ApplyOp( + Parens( + ApplyOp( + Parens(Literal(Lit.Str("igyimc"))), + Operator("|"), + Var(Name("ncf5Eo9")) + ) + ), + Operator("|"), + Var(Constructor("K")) + ) ) } @@ -719,10 +980,12 @@ x""") val decStr = dec.toDoc.render(80) val parsePat = optionParse(Pattern.matchParser, decStr) (Declaration.toPattern(dec), parsePat) match { - case (None, None) => succeed + case (None, None) => succeed case (Some(p0), Some(p1)) => assert(p0 == p1) - case (None, Some(_)) => fail(s"toPattern failed, but parsed $decStr to: $parsePat") - case (Some(p), None) => fail(s"toPattern succeeded: $p but pattern parse failed") + case (None, Some(_)) => + fail(s"toPattern failed, but parsed $decStr to: $parsePat") + case (Some(p), None) => + fail(s"toPattern succeeded: $p but pattern parse failed") } } @@ -730,13 +993,13 @@ x""") forAll(Generators.genNonBinding(5))(law2(_)) regressions.foreach(law2(_)) - def testEqual(decl: String) = { - val dec = unsafeParse(Declaration.parser(""), decl).asInstanceOf[Declaration.NonBinding] + val dec = unsafeParse(Declaration.parser(""), decl) + .asInstanceOf[Declaration.NonBinding] val patt = unsafeParse(Pattern.matchParser, decl) Declaration.toPattern(dec) match { case Some(p2) => assert(p2 == patt) - case None => fail(s"could not convert $decl to pattern") + case None => fail(s"could not convert $decl to pattern") } } @@ -749,28 +1012,70 @@ x""") test("we can parse bind") { import Declaration._ - parseTestAll(parser(""), + parseTestAll( + parser(""), """x = 4 x""", - Binding(BindingStatement(Pattern.Var(Identifier.Name("x")), Literal(Lit.fromInt(4)), Padding(0, mkVar("x"))))) + Binding( + BindingStatement( + Pattern.Var(Identifier.Name("x")), + Literal(Lit.fromInt(4)), + Padding(0, mkVar("x")) + ) + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), """x = foo(4) x""", - Binding(BindingStatement(Pattern.Var(Identifier.Name("x")), Apply(mkVar("foo"), NonEmptyList.of(Literal(Lit.fromInt(4))), ApplyKind.Parens), Padding(1, mkVar("x"))))) + Binding( + BindingStatement( + Pattern.Var(Identifier.Name("x")), + Apply( + mkVar("foo"), + NonEmptyList.of(Literal(Lit.fromInt(4))), + ApplyKind.Parens + ), + Padding(1, mkVar("x")) + ) + ) + ) - parseTestAll(parser(""), + parseTestAll( + parser(""), """x = foo(4) # x is really great x""", - Binding(BindingStatement(Pattern.Var(Identifier.Name("x")),Apply(mkVar("foo"),NonEmptyList.of(Literal(Lit.fromInt(4))), ApplyKind.Parens),Padding(0,CommentNB(CommentStatement(NonEmptyList.of(" x is really great"),Padding(0,mkVar("x")))))))) + Binding( + BindingStatement( + Pattern.Var(Identifier.Name("x")), + Apply( + mkVar("foo"), + NonEmptyList.of(Literal(Lit.fromInt(4))), + ApplyKind.Parens + ), + Padding( + 0, + CommentNB( + CommentStatement( + NonEmptyList.of(" x is really great"), + Padding(0, mkVar("x")) + ) + ) + ) + ) + ) + ) // allow indentation after = - roundTrip(parser(""), + roundTrip( + parser(""), """x = | foo - |x""".stripMargin) + |x""".stripMargin + ) } test("we can parse if") { @@ -780,103 +1085,139 @@ x""", val liftVar0 = Parser.Indy.lift(varP: P[NonBinding]) val parser0 = ifElseP(liftVar0, liftVar)("") - roundTrip[Declaration](parser0, + roundTrip[Declaration]( + parser0, """if w: x else: - y""") + y""" + ) - roundTrip[Declaration](parser0, + roundTrip[Declaration]( + parser0, """if w: | x |else: - | y""".stripMargin) + | y""".stripMargin + ) - roundTrip(parser(""), + roundTrip( + parser(""), """if eq_Int(x, 3): x else: - y""") + y""" + ) - expectFail(parser0, + expectFail( + parser0, """if x: x else - y""", 18) + y""", + 18 + ) - expectFail(parser0, + expectFail( + parser0, """if x: x -else y""", 13) +else y""", + 13 + ) - expectFail(parser(""), + expectFail( + parser(""), """if x: x else - y""", 18) + y""", + 18 + ) - expectFail(parser(""), + expectFail( + parser(""), """if x: x -else y""", 13) +else y""", + 13 + ) - expectFail(parser(""), + expectFail( + parser(""), """if f: 0 -else 1""", 13) +else 1""", + 13 + ) - roundTrip(parser(""), + roundTrip( + parser(""), """if eq_Int(x, 3): x elif foo: z else: - y""") + y""" + ) - roundTrip[Declaration](parser0, + roundTrip[Declaration]( + parser0, """if w: x -else: y""") - roundTrip(parser(""), +else: y""" + ) + roundTrip( + parser(""), """if eq_Int(x, 3): x -else: y""") +else: y""" + ) - roundTrip(parser(""), + roundTrip( + parser(""), """if eq_Int(x, 3): x elif foo: z -else: y""") +else: y""" + ) } test("we can parse a match") { val liftVar = Parser.Indy.lift(Declaration.varP: P[Declaration]) val liftVar0 = Parser.Indy.lift(Declaration.varP: P[Declaration.NonBinding]) - roundTrip[Declaration](Declaration.matchP(liftVar0, liftVar)(""), -"""match x: + roundTrip[Declaration]( + Declaration.matchP(liftVar0, liftVar)(""), + """match x: case y: z case w: - r""") - roundTrip(Declaration.parser(""), -"""match 1: + r""" + ) + roundTrip( + Declaration.parser(""), + """match 1: case Foo(a, b): a.plus(b) case Bar: - 42""") - roundTrip(Declaration.parser(""), - -"""match 1: + 42""" + ) + roundTrip( + Declaration.parser(""), + """match 1: case (a, b): a.plus(b) case (): - 42""") + 42""" + ) - roundTrip(Declaration.parser(""), - -"""match 1: + roundTrip( + Declaration.parser(""), + """match 1: case (a, (b, c)): a.plus(b).plus(e) case (1,): - 42""") + 42""" + ) - roundTrip(Declaration.parser(""), -"""match 1: + roundTrip( + Declaration.parser(""), + """match 1: case Foo(a, b): a.plus(b) case Bar: @@ -884,20 +1225,24 @@ else: y""") case True: 100 case False: - 99""") + 99""" + ) - roundTrip(Declaration.parser(""), -"""foo(1, match 2: + roundTrip( + Declaration.parser(""), + """foo(1, match 2: case Foo: foo case Bar: # this is the bar case - bar, 100)""") + bar, 100)""" + ) - roundTrip(Declaration.parser(""), -"""if (match 2: + roundTrip( + Declaration.parser(""), + """if (match 2: case Foo: foo @@ -907,93 +1252,127 @@ else: y""") bar): 1 else: - 2""") + 2""" + ) - roundTrip(Declaration.parser(""), -"""if True: + roundTrip( + Declaration.parser(""), + """if True: match 1: case Foo(f): 1 else: - 100""") + 100""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Bar(_, _): - 10""") + 10""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Bar(_, _): if True: 0 - else: 10""") + else: 10""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Bar(_, _): if True: 0 - else: 10""") + else: 10""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case []: 0 case [x]: 1 - case _: 2""") + case _: 2""" + ) - roundTrip(Declaration.parser(""), -"""Foo(x) = bar -x""") + roundTrip( + Declaration.parser(""), + """Foo(x) = bar +x""" + ) - roundTrip(Declaration.parser(""), -"""Foo { x } = bar -x""") + roundTrip( + Declaration.parser(""), + """Foo { x } = bar +x""" + ) - roundTrip(Declaration.parser(""), -"""Foo { x } = Foo{x:1} -x""") + roundTrip( + Declaration.parser(""), + """Foo { x } = Foo{x:1} +x""" + ) - roundTrip(Declaration.parser(""), -"""match x: - case Some(_) | None: 1""") + roundTrip( + Declaration.parser(""), + """match x: + case Some(_) | None: 1""" + ) - roundTrip(Declaration.parser(""), -"""match x: + roundTrip( + Declaration.parser(""), + """match x: case Some(_) | None: 1 case y: y - case [x | y, _]: z""") - + case [x | y, _]: z""" + ) - roundTrip(Declaration.parser(""), -"""Foo(x) | Bar(x) = bar -x""") + roundTrip( + Declaration.parser(""), + """Foo(x) | Bar(x) = bar +x""" + ) - roundTrip(Declaration.parser(""), -"""(x: Int) = bar -x""") - roundTrip(Declaration.parser(""), -"""x: Int = bar -x""") + roundTrip( + Declaration.parser(""), + """(x: Int) = bar +x""" + ) + roundTrip( + Declaration.parser(""), + """x: Int = bar +x""" + ) } test("we allow extra indentation on elif and else for better alignment") { - roundTrip(Declaration.parser(""), + roundTrip( + Declaration.parser(""), """z = if w: | x | else: | y - |z""".stripMargin) + |z""".stripMargin + ) - roundTrip(Declaration.parser(""), + roundTrip( + Declaration.parser(""), """z = if w: x | elif y: z | else: quux - |z""".stripMargin) + |z""".stripMargin + ) } test("we can parse declaration lists") { - val ll = ListLang.parser(Declaration.parser(""), Declaration.nonBindingParserNoTern(""), Pattern.matchParser) + val ll = ListLang.parser( + Declaration.parser(""), + Declaration.nonBindingParserNoTern(""), + Pattern.matchParser + ) roundTrip(Declaration.parser(""), "[]") roundTrip(Declaration.parser(""), "[1]") @@ -1010,13 +1389,18 @@ x""") roundTrip(ll, "[x for x in range(4) if x.eq_Int(2)]") roundTrip(ListLang.SpliceOrItem.parser(Declaration.parser("")), "a") roundTrip(ListLang.SpliceOrItem.parser(Declaration.parser("")), "foo(a, b)") - roundTrip(ListLang.SpliceOrItem.parser(Declaration.parser("")), "*foo(a, b)") + roundTrip( + ListLang.SpliceOrItem.parser(Declaration.parser("")), + "*foo(a, b)" + ) roundTrip(Declaration.parser(""), "[x for y in [1, 2]]") roundTrip(Declaration.parser(""), "[x for y in [1, 2] if foo]") } test("we can parse any Declaration") { - forAll(Generators.genDeclaration(5))(law(Declaration.parser("").map(_.replaceRegions(emptyRegion)))) + forAll(Generators.genDeclaration(5))( + law(Declaration.parser("").map(_.replaceRegions(emptyRegion))) + ) def decl(s: String) = roundTrip(Declaration.parser(""), s) @@ -1064,26 +1448,35 @@ x""") } test("we can parse any Statement") { - forAll(Generators.genStatements(4, 10))(law(Statement.parser.map(_.map(_.replaceRegions(emptyRegion))))) - - roundTrip(Statement.parser, -"""# -def foo(x): x""") - - roundTrip(Statement.parser, -"""# + forAll(Generators.genStatements(4, 10))( + law(Statement.parser.map(_.map(_.replaceRegions(emptyRegion)))) + ) + + roundTrip( + Statement.parser, + """# +def foo(x): x""" + ) + + roundTrip( + Statement.parser, + """# def foo(x): - x""") + x""" + ) - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# operator + = plus x = 1+2 -""") +""" + ) - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header y = if eq_Int(x, 2): True else: @@ -1097,10 +1490,12 @@ fn = \x, y -> x.plus(y) x = ( foo ) -""") +""" + ) - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header def foo(x: forall f. f[a] -> f[b], y: a) -> b: x(y) @@ -1109,11 +1504,13 @@ fn = \x, y -> x.plus(y) x = ( foo ) -""") +""" + ) // we can add spaces at the end of the file - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header def foo(x: forall f. f[a] -> f[b], y: a) -> b: x(y) @@ -1121,78 +1518,99 @@ def foo(x: forall f. f[a] -> f[b], y: a) -> b: fn = \x, y -> x.plus(y) x = ( foo ) - """) + """ + ) - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# x = Pair([], b) -""") +""" + ) - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# Pair(x, _) = Pair([], b) -""") +""" + ) - roundTrip(Statement.parser, -"""# MONADS!!!! + roundTrip( + Statement.parser, + """# MONADS!!!! struct Monad(pure: forall a. a -> f[a], flatMap: forall a, b. f[a] -> (a -> f[b]) -> f[b]) -""") +""" + ) // we can put new-lines in structs - roundTrip(Statement.parser, -"""# MONADS!!!! + roundTrip( + Statement.parser, + """# MONADS!!!! struct Monad( pure: forall a. a -> f[a], flatMap: forall a, b. f[a] -> (a -> f[b]) -> f[b]) -""") +""" + ) // we can put type params in - roundTrip(Statement.parser, -"""# MONADS!!!! + roundTrip( + Statement.parser, + """# MONADS!!!! struct Monad[f]( pure: forall a. a -> f[a], flatMap: forall a, b. f[a] -> (a -> f[b]) -> f[b]) -""") +""" + ) // we can put new-lines in defs - roundTrip(Statement.parser, -"""# + roundTrip( + Statement.parser, + """# def foo( x, y: Int): x.add(y) -""") +""" + ) roundTrip(Statement.parser, """enum Option: None, Some(a)""") roundTrip(Statement.parser, """enum Option[a]: None, Some(a: a)""") - roundTrip(Statement.parser, -"""enum Option: + roundTrip( + Statement.parser, + """enum Option: None - Some(a)""") + Some(a)""" + ) - roundTrip(Statement.parser, -"""enum Option[a]: + roundTrip( + Statement.parser, + """enum Option[a]: None - Some(a: a)""") - - roundTrip(Statement.parser, -"""enum Option: - None, Some(a)""") - - roundTripExact(Statement.parser, -"""def run(z): + Some(a: a)""" + ) + + roundTrip( + Statement.parser, + """enum Option: + None, Some(a)""" + ) + + roundTripExact( + Statement.parser, + """def run(z): Err(y) | Good(y) = z y -""") +""" + ) } def dropTrailingPadding(s: List[Statement]): List[Statement] = s.reverse.dropWhile { case Statement.PaddingStatement(_) => true - case _ => false + case _ => false }.reverse test("Any statement may append trailing whitespace and continue to parse") { @@ -1202,55 +1620,77 @@ def foo( } } - test("Any statement ending in a newline may have it removed and continue to parse") { + test( + "Any statement ending in a newline may have it removed and continue to parse" + ) { forAll(Generators.genStatement(5)) { s => val str = Document[Statement].document(s).render(80) - roundTrip(Statement.parser.map(dropTrailingPadding(_)), str.reverse.dropWhile(_ == '\n').reverse) + roundTrip( + Statement.parser.map(dropTrailingPadding(_)), + str.reverse.dropWhile(_ == '\n').reverse + ) } } - test("Any declaration may append any whitespace and optionally a comma and parse") { - forAll(Generators.genDeclaration(4), Gen.listOf(Gen.oneOf(' ', '\t')).map(_.mkString), Gen.oneOf(true, false)) { - case (s, ws, comma) => - val str = Document[Declaration].document(s).render(80) + ws + (if (comma) "," else "") - roundTrip(Declaration.parser(""), str, lax = true) + test( + "Any declaration may append any whitespace and optionally a comma and parse" + ) { + forAll( + Generators.genDeclaration(4), + Gen.listOf(Gen.oneOf(' ', '\t')).map(_.mkString), + Gen.oneOf(true, false) + ) { case (s, ws, comma) => + val str = + Document[Declaration].document(s).render(80) + ws + (if (comma) "," + else "") + roundTrip(Declaration.parser(""), str, lax = true) } } test("parse external defs") { - roundTrip(Statement.parser, -"""# header + roundTrip( + Statement.parser, + """# header external foo: String -""") - roundTrip(Statement.parser, -"""# header +""" + ) + roundTrip( + Statement.parser, + """# header external def foo(i: Integer) -> String -""") - roundTrip(Statement.parser, -"""# header +""" + ) + roundTrip( + Statement.parser, + """# header external def foo(i: Integer, b: a) -> String external def foo2(i: Integer, b: a) -> String -""") +""" + ) } - test("we can parse any package") { - roundTrip(Package.parser(None), -""" + roundTrip( + Package.parser(None), + """ package Foo/Bar from Baz import Bippy export foo foo = 1 -""") +""" + ) - val pp = Package.parser(None).map { pack => pack.copy(program = pack.program.map(_.replaceRegions(emptyRegion))) } + val pp = Package.parser(None).map { pack => + pack.copy(program = pack.program.map(_.replaceRegions(emptyRegion))) + } forAll(Generators.packageGen(4))(law(pp)) - roundTripExact(Package.parser(None), -"""package Foo + roundTripExact( + Package.parser(None), + """package Foo enum Res[a, b]: Err(a: a), Good(a: a, b: b) @@ -1261,104 +1701,141 @@ def run(z): y main = run(x) -""") +""" + ) } test("parse errors point near where they occur") { - expectFail(Statement.parser, + expectFail( + Statement.parser, """x = 1 z = 3 z = 4 y = {'x': 'x' : 'y'} -""", 32) +""", + 32 + ) - expectFail(Statement.parser, + expectFail( + Statement.parser, """x = 1 z = ( x = 1 x x) -""", 24) +""", + 24 + ) - expectFail(Statement.parser, + expectFail( + Statement.parser, """x = 1 z = ( x = 1 y = [1, 2, 3] x x) -""", 40) +""", + 40 + ) - expectFail(Statement.parser, + expectFail( + Statement.parser, """z = ( if f: 0 else 1) -""", 23) +""", + 23 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo from Baz import a, , b x = 1 -""", 31) +""", + 31 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo export x, , y x = 1 -""", 22) +""", + 22 + ) - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo export x, , x = 1 -""", 22) - expectFail(Package.parser(None), +""", + 22 + ) + expectFail( + Package.parser(None), """package Foo x = Foo(bar if bar) -""", 31) +""", + 31 + ) - - expectFail(Package.parser(None), + expectFail( + Package.parser(None), """package Foo z = [x for x in xs if x < y else ] -""", 41) +""", + 41 + ) } test("using parens to make blocks") { - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( y = 3 y ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( # some pattern matching Foo(y, _) = foo y ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( # an if/else block if True: 1 else: 0 ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( def foo(x): x @@ -1366,10 +1843,13 @@ x = ( foo(1) ) ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( # here is foo @@ -1379,32 +1859,43 @@ x = ( foo(1) ) ) -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = ( y = 3 y ) -""", lax = true) +""", + lax = true + ) } test("lambdas can have new lines") { - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = z -> z -""", lax = true) +""", + lax = true + ) - roundTrip(Package.parser(None), -"""package Foo + roundTrip( + Package.parser(None), + """package Foo x = z -> # we can comment here z -""", lax = true) +""", + lax = true + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala b/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala index 5b9935de0..9eaf4b972 100644 --- a/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.data.NonEmptyList import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration} +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class PatternTest extends AnyFunSuite { @@ -23,7 +26,9 @@ class PatternTest extends AnyFunSuite { test("filtering for names not in a pattern is unbind") { forAll(patGen, Gen.listOf(Gen.identifier)) { (p, ids0) => val ids = ids0.map(Identifier.unsafe(_)) - assert(p.unbind == p.filterVars(ids.toSet.filterNot(p.names.toSet[Identifier]))) + assert( + p.unbind == p.filterVars(ids.toSet.filterNot(p.names.toSet[Identifier])) + ) } } @@ -60,27 +65,38 @@ class PatternTest extends AnyFunSuite { // we can name with the same name, and still be singly named assert(Pattern.SinglyNamed.unapply(Pattern.Named(n, p)) == Some(n)) // we can annotate and not lose singly named-ness - assert(Pattern.SinglyNamed.unapply(Pattern.Annotation(p, null)) == Some(n)) + assert( + Pattern.SinglyNamed.unapply(Pattern.Annotation(p, null)) == Some(n) + ) // we can make a union and not lose singly named-ness - assert(Pattern.SinglyNamed.unapply(Pattern.union(Pattern.Var(n), p :: Nil)) == Some(n)) + assert( + Pattern.SinglyNamed.unapply( + Pattern.union(Pattern.Var(n), p :: Nil) + ) == Some(n) + ) case _ => } forAll(patGen) { p => law(p) } - law(Pattern.Named(Identifier.Name("x"), Pattern.Named(Identifier.Name("x"), Pattern.WildCard))) + law( + Pattern.Named( + Identifier.Name("x"), + Pattern.Named(Identifier.Name("x"), Pattern.WildCard) + ) + ) } test("test some examples for singly named") { def check(str: String, nm: String) = pat(str) match { case Pattern.SinglyNamed(n) => assert(n == Identifier.unsafe(nm)) - case other => fail(s"expected singlynamed: $other") + case other => fail(s"expected singlynamed: $other") } def checkNot(str: String) = pat(str) match { case Pattern.SinglyNamed(n) => fail(s"unexpected singlynamed: $n") - case _ => succeed + case _ => succeed } check("foo", "foo") @@ -109,7 +125,12 @@ class PatternTest extends AnyFunSuite { val bar = Identifier.Name("bar") assert(Pattern.Var(foo).substructures.isEmpty) assert(Pattern.Annotation(Pattern.Var(foo), "Type").substructures.isEmpty) - assert(Pattern.Union(Pattern.Var(foo), NonEmptyList.of(Pattern.Var(bar))).substructures.isEmpty) + assert( + Pattern + .Union(Pattern.Var(foo), NonEmptyList.of(Pattern.Var(bar))) + .substructures + .isEmpty + ) } test("unions with total matches work correctly") { diff --git a/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala b/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala index 783aaf24e..3cb3ae82f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/SelfCallKindTest.scala @@ -2,14 +2,17 @@ package org.bykn.bosatsu import org.scalacheck.Gen import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.bykn.bosatsu.rankn.NTypeGen import org.bykn.bosatsu.TestUtils.checkLast import org.bykn.bosatsu.Identifier.Name class SelfCallKindTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) def gen[A](g: Gen[A]): Gen[TypedExpr[A]] = @@ -21,8 +24,7 @@ class SelfCallKindTest extends AnyFunSuite { test("test selfCallKind") { import SelfCallKind.{NoCall, NonTailCall, TailCall, apply => selfCallKind} - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) enum N: Z, S(prev: N) @@ -32,8 +34,7 @@ def list_len(list, acc): case NE(_, t): list_len(t, S(acc)) """) { te => assert(selfCallKind(Name("list_len"), te) == TailCall) } - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) enum N: Z, S(prev: N) @@ -43,8 +44,7 @@ def list_len(list): case NE(_, t): S(list_len(t)) """) { te => assert(selfCallKind(Name("list_len"), te) == NonTailCall) } - checkLast( - """ + checkLast(""" enum List[a]: E, NE(head: a, tail: List[a]) def list_len(list): @@ -85,4 +85,4 @@ def list_len(list): } } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala b/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala index 865c84e36..b7529e91f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/SourceConverterTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration} +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Identifier.Bindable @@ -10,14 +13,19 @@ import org.scalatest.funsuite.AnyFunSuite class SourceConverterTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 3000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 3000 else 20 + ) val genRec = Gen.oneOf(RecursionKind.NonRecursive, RecursionKind.Recursive) test("makeLetsUnique preserves let count") { val genLets = for { cnt <- Gen.choose(0, 100) - lets <- Gen.listOfN(cnt, Gen.zip(Generators.bindIdentGen, genRec, Gen.const(()))) + lets <- Gen.listOfN( + cnt, + Gen.zip(Generators.bindIdentGen, genRec, Gen.const(())) + ) } yield lets forAll(genLets) { lets => @@ -41,7 +49,8 @@ class SourceConverterTest extends AnyFunSuite { names <- Gen.listOfN(cnt, Generators.bindIdentGen) namesDistinct = names.distinct lets <- Generators.traverseGen(namesDistinct) { nm => - Gen.zip(genRec, Gen.choose(0, 10)) + Gen + .zip(genRec, Gen.choose(0, 10)) .map { case (r, d) => (nm, r, d) } } } yield lets @@ -58,7 +67,10 @@ class SourceConverterTest extends AnyFunSuite { test("makeLetsUnique applies to rhs for recursive binds") { val genLets = for { cnt <- Gen.choose(0, 100) - lets <- Gen.listOfN(cnt, Generators.bindIdentGen.map { b => (b, RecursionKind.Recursive, b) }) + lets <- Gen.listOfN( + cnt, + Generators.bindIdentGen.map { b => (b, RecursionKind.Recursive, b) } + ) } yield lets forAll(genLets) { lets => @@ -77,12 +89,33 @@ class SourceConverterTest extends AnyFunSuite { { // non recursive val l1 = List( - (Identifier.Name("b"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("c"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("d"), RecursionKind.NonRecursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String])) + ( + Identifier.Name("b"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("a"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("c"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("a"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + ( + Identifier.Name("d"), + RecursionKind.NonRecursive, + Option.empty[String] + ), + (Identifier.Name("a"), RecursionKind.NonRecursive, Option.empty[String]) + ) val up1 = SourceConverter.makeLetsUnique(l1) { case (Identifier.Name(n), idx) => @@ -99,7 +132,8 @@ class SourceConverterTest extends AnyFunSuite { (Identifier.Name("c"), RecursionKind.NonRecursive, Some("a0")), (Identifier.Name("a1"), RecursionKind.NonRecursive, Some("a0")), (Identifier.Name("d"), RecursionKind.NonRecursive, Some("a1")), - (Identifier.Name("a"), RecursionKind.NonRecursive, Some("a1"))) + (Identifier.Name("a"), RecursionKind.NonRecursive, Some("a1")) + ) assert(up1 == expectl1) } @@ -111,7 +145,8 @@ class SourceConverterTest extends AnyFunSuite { (Identifier.Name("c"), RecursionKind.Recursive, Option.empty[String]), (Identifier.Name("a"), RecursionKind.Recursive, Option.empty[String]), (Identifier.Name("d"), RecursionKind.Recursive, Option.empty[String]), - (Identifier.Name("a"), RecursionKind.Recursive, Option.empty[String])) + (Identifier.Name("a"), RecursionKind.Recursive, Option.empty[String]) + ) val up1 = SourceConverter.makeLetsUnique(l1) { case (Identifier.Name(n), idx) => @@ -128,7 +163,8 @@ class SourceConverterTest extends AnyFunSuite { (Identifier.Name("c"), RecursionKind.Recursive, Some("a0")), (Identifier.Name("a1"), RecursionKind.Recursive, Some("a1")), (Identifier.Name("d"), RecursionKind.Recursive, Some("a1")), - (Identifier.Name("a"), RecursionKind.Recursive, None)) + (Identifier.Name("a"), RecursionKind.Recursive, None) + ) assert(up1 == expectl1) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index 575b3c44b..9f27f77c4 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -10,13 +10,16 @@ import IorMethods.IorExtension object TestUtils { - def parsedTypeEnvOf(pack: PackageName, str: String): ParsedTypeEnv[Option[Kind.Arg]] = { + def parsedTypeEnvOf( + pack: PackageName, + str: String + ): ParsedTypeEnv[Option[Kind.Arg]] = { val stmt = statementsOf(str) val prog = SourceConverter.toProgram(pack, Nil, stmt) match { - case Ior.Right(prog) => prog + case Ior.Right(prog) => prog case Ior.Both(_, prog) => prog - case Ior.Left(err) => sys.error(err.toString) + case Ior.Left(err) => sys.error(err.toString) } prog.types._2 } @@ -24,9 +27,9 @@ object TestUtils { val predefParsedTypeEnv: ParsedTypeEnv[Option[Kind.Arg]] = { val p = Package.predefPackage val prog = SourceConverter.toProgram(p.name, Nil, p.program) match { - case Ior.Right(prog) => prog + case Ior.Right(prog) => prog case Ior.Both(_, prog) => prog - case Ior.Left(err) => sys.error(err.toString) + case Ior.Left(err) => sys.error(err.toString) } prog.types._2 } @@ -37,16 +40,15 @@ object TestUtils { def statementsOf(str: String): List[Statement] = Parser.unsafeParse(Statement.parser, str) - /** - * Make sure no illegal final types escaped into a TypedExpr - */ + /** Make sure no illegal final types escaped into a TypedExpr + */ def assertValid[A](te: TypedExpr[A]): Unit = { def checkType(t: Type): Type = t match { - case t@Type.TyVar(Type.Var.Skolem(_, _, _)) => + case t @ Type.TyVar(Type.Var.Skolem(_, _, _)) => sys.error(s"illegal skolem ($t) escape in ${te.repr}") case Type.TyVar(Type.Var.Bound(_)) => t - case t@Type.TyMeta(_) => + case t @ Type.TyMeta(_) => sys.error(s"illegal meta ($t) escape in ${te.repr}") case Type.TyApply(left, right) => Type.TyApply(checkType(left), checkType(right)) @@ -57,24 +59,32 @@ object TestUtils { te.traverseType[cats.Id](checkType) val tp = te.getType lazy val teStr = Type.fullyResolvedDocument.document(tp).render(80) - scala.Predef.require(Type.freeTyVars(tp :: Nil).isEmpty, - s"illegal inferred type: $teStr in: ${te.repr}") - - scala.Predef.require(Type.metaTvs(tp :: Nil).isEmpty, - s"illegal inferred type: $teStr in: ${te.repr}") + scala.Predef.require( + Type.freeTyVars(tp :: Nil).isEmpty, + s"illegal inferred type: $teStr in: ${te.repr}" + ) + + scala.Predef.require( + Type.metaTvs(tp :: Nil).isEmpty, + s"illegal inferred type: $teStr in: ${te.repr}" + ) } val testPackage: PackageName = PackageName.parts("Test") - def checkLast(statement: String)(fn: TypedExpr[Declaration] => Assertion): Assertion = { + def checkLast( + statement: String + )(fn: TypedExpr[Declaration] => Assertion): Assertion = { val stmts = Parser.unsafeParse(Statement.parser, statement) Package.inferBody(testPackage, Nil, stmts).strictToValidated match { case Validated.Invalid(errs) => val lm = LocationMap(statement) val packMap = Map((testPackage, (lm, statement))) - val msg = errs.toList.map { err => - err.message(packMap, LocationMap.Colorize.None) - }.mkString("\n==========\n") + val msg = errs.toList + .map { err => + err.message(packMap, LocationMap.Colorize.None) + } + .mkString("\n==========\n") fail("inference failure: " + msg) case Validated.Valid(program) => // make sure all the TypedExpr are valid @@ -84,7 +94,9 @@ object TestUtils { } def makeInputArgs(files: List[(Int, Any)]): List[String] = - ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { case (idx, _) => "--input" :: idx.toString :: Nil } + ("--package_root" :: Int.MaxValue.toString :: Nil) ::: files.flatMap { + case (idx, _) => "--input" :: idx.toString :: Nil + } private val module = new MemoryMain[Either[Throwable, *], Int]({ idx => if (idx == Int.MaxValue) Nil @@ -94,24 +106,37 @@ object TestUtils { def evalTest(packages: List[String], mainPackS: String, expected: Value) = { val files = packages.zipWithIndex.map(_.swap) - module.runWith(files)("eval" :: "--main" :: mainPackS :: makeInputArgs(files)) match { + module.runWith(files)( + "eval" :: "--main" :: mainPackS :: makeInputArgs(files) + ) match { case Right(module.Output.EvaluationResult(got, _, gotDoc)) => val gv = got.value - assert(gv == expected, s"${gotDoc.value.render(80)}\n\n$gv != $expected") + assert( + gv == expected, + s"${gotDoc.value.render(80)}\n\n$gv != $expected" + ) case Right(other) => fail(s"got an unexpected success: $other") case Left(err) => module.mainExceptionToString(err) match { case Some(msg) => fail(msg) - case None => fail(s"got an exception: $err") + case None => fail(s"got an exception: $err") } } } - def evalTestJson(packages: List[String], mainPackS: String, expected: Json) = { + def evalTestJson( + packages: List[String], + mainPackS: String, + expected: Json + ) = { val files = packages.zipWithIndex.map(_.swap) - module.runWith(files)("json" :: "write" :: "--main" :: mainPackS :: "--output" :: "-1" :: makeInputArgs(files)) match { + module.runWith(files)( + "json" :: "write" :: "--main" :: mainPackS :: "--output" :: "-1" :: makeInputArgs( + files + ) + ) match { case Right(module.Output.JsonOutput(got, _)) => assert(got == expected, s"$got != $expected") case Right(other) => @@ -121,15 +146,25 @@ object TestUtils { } } - def runBosatsuTest(packages: List[String], mainPackS: String, assertionCount: Int) = { + def runBosatsuTest( + packages: List[String], + mainPackS: String, + assertionCount: Int + ) = { val files = packages.zipWithIndex.map(_.swap) - module.runWith(files)("test" :: "--test_package" :: mainPackS :: makeInputArgs(files)) match { + module.runWith(files)( + "test" :: "--test_package" :: mainPackS :: makeInputArgs(files) + ) match { case Right(module.Output.TestOutput(results, _)) => results.collect { case (_, Some(t)) => t.value } match { case t :: Nil => - assert(t.assertions == assertionCount, s"${t.assertions} != $assertionCount") - val (_, failcount, message) = Test.report(t, LocationMap.Colorize.None) + assert( + t.assertions == assertionCount, + s"${t.assertions} != $assertionCount" + ) + val (_, failcount, message) = + Test.report(t, LocationMap.Colorize.None) assert(t.failures.map(_.assertions).getOrElse(0) == failcount) if (failcount > 0) fail(message.render(80)) else succeed @@ -146,44 +181,54 @@ object TestUtils { } } - def testInferred(packages: List[String], mainPackS: String, inferredHandler: (PackageMap.Inferred, PackageName) => Assertion)(implicit ec: Par.EC) = { + def testInferred( + packages: List[String], + mainPackS: String, + inferredHandler: (PackageMap.Inferred, PackageName) => Assertion + )(implicit ec: Par.EC) = { val mainPack = PackageName.parse(mainPackS).get val parsed = packages.zipWithIndex.traverse { case (pack, i) => Parser.parse(Package.parser(None), pack).map { case (lm, parsed) => ((i.toString, lm), parsed) } - } + } val parsedPaths = parsed match { case Validated.Valid(vs) => vs case Validated.Invalid(errs) => errs.toList.foreach { p => - System.err.println(p.showContext(LocationMap.Colorize.None).render(80)) + System.err.println( + p.showContext(LocationMap.Colorize.None).render(80) + ) } - sys.error("failed to parse") //errs.toString) + sys.error("failed to parse") // errs.toString) } val fullParsed = - PackageMap.withPredefA(("predef", LocationMap("")), parsedPaths) - .map { case ((path, _), p) => (path, p) } + PackageMap + .withPredefA(("predef", LocationMap("")), parsedPaths) + .map { case ((path, _), p) => (path, p) } PackageMap - .resolveThenInfer(fullParsed , Nil).strictToValidated match { - case Validated.Valid(packMap) => - inferredHandler(packMap, mainPack) - - case Validated.Invalid(errs) => - val tes = errs.toList.collect { - case PackageError.TypeErrorIn(te, _) => - te.toString + .resolveThenInfer(fullParsed, Nil) + .strictToValidated match { + case Validated.Valid(packMap) => + inferredHandler(packMap, mainPack) + + case Validated.Invalid(errs) => + val tes = errs.toList + .collect { case PackageError.TypeErrorIn(te, _) => + te.toString } .mkString("\n") - fail(tes + "\n" + errs.toString) - } + fail(tes + "\n" + errs.toString) + } } - def evalFail(packages: List[String])(errFn: PartialFunction[PackageError, Unit])(implicit ec: Par.EC) = { + def evalFail( + packages: List[String] + )(errFn: PartialFunction[PackageError, Unit])(implicit ec: Par.EC) = { val parsed = packages.zipWithIndex.traverse { case (pack, i) => Parser.parse(Package.parser(None), pack).map { case (lm, parsed) => @@ -212,9 +257,8 @@ object TestUtils { errs.toList.foreach(_.message(sm, LocationMap.Colorize.None)) assert(true) case Some(errs) => - fail(s"failed, but no type errors: $errs") + fail(s"failed, but no type errors: $errs") } } - } diff --git a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala index 55e47bd34..e33891a5e 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TotalityTest.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu import cats.Eq import cats.data.NonEmptyList -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalacheck.Gen import org.bykn.bosatsu.pattern.{SetOps, SetOpsLaws} @@ -17,13 +20,16 @@ import Identifier.Constructor import cats.implicits._ -class TotalityTest extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] { +class TotalityTest + extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] { type Pat = Pattern[(PackageName, Constructor), Type] implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 50000 else 100) - //PropertyCheckConfiguration(minSuccessful = 50) + // PropertyCheckConfiguration(minSuccessful = 50000) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 50000 else 100 + ) + // PropertyCheckConfiguration(minSuccessful = 50) val genPattern: Gen[Pattern[(PackageName, Constructor), Type]] = Generators.genCompiledPattern(5, useAnnotation = false) @@ -32,9 +38,8 @@ class TotalityTest extends SetOpsLaws[Pattern[(PackageName, Constructor), Type]] Generators.genCompiledPattern(5, useUnion = false, useAnnotation = false) def showPat(pat: Pattern[(PackageName, Constructor), Type]): String = { - val pat0 = pat.mapName { - case (_, n) => - Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike) + val pat0 = pat.mapName { case (_, n) => + Pattern.StructKind.Named(n, Pattern.StructKind.Style.TupleLike) } implicit val tdoc = Type.fullyResolvedDocument @@ -70,62 +75,76 @@ enum Bool: False, True new Eq[List[Pattern[(PackageName, Constructor), Type]]] { val e1 = TotalityCheck(predefTE).eqPat - def eqv(a: List[Pattern[(PackageName, Constructor), Type]], - b: List[Pattern[(PackageName, Constructor), Type]]) = - (NonEmptyList.fromList(a), NonEmptyList.fromList(b)) match { - case (oa, ob) if oa == ob => true - case (Some(a), Some(b)) => - e1.eqv(Pattern.union(a.head, a.tail), Pattern.union(b.head, b.tail)) - case _ => false - } + def eqv( + a: List[Pattern[(PackageName, Constructor), Type]], + b: List[Pattern[(PackageName, Constructor), Type]] + ) = + (NonEmptyList.fromList(a), NonEmptyList.fromList(b)) match { + case (oa, ob) if oa == ob => true + case (Some(a), Some(b)) => + e1.eqv(Pattern.union(a.head, a.tail), Pattern.union(b.head, b.tail)) + case _ => false + } } def eqUnion: Gen[Eq[List[Pattern[(PackageName, Constructor), Type]]]] = Gen.const(eqPatterns) def patterns(str: String): List[Pattern[(PackageName, Constructor), Type]] = { - val nameToCons: Constructor => (PackageName, Constructor) = - { cons => (PackageName.PredefName, cons) } - - /** - * This is sufficient for these tests, but is not - * a full features pattern compiler. - */ - def parsedToExpr(pat: Pattern.Parsed): Pattern[(PackageName, Constructor), rankn.Type] = - pat.mapStruct[(PackageName, Constructor)] { - case (Pattern.StructKind.Tuple, args) => - // this is a tuple pattern - def loop(args: List[Pattern[(PackageName, Constructor), TypeRef]]): Pattern[(PackageName, Constructor), TypeRef] = - args match { - case Nil => - // () - Pattern.PositionalStruct( - (PackageName.PredefName, Constructor("Unit")), - Nil) - case h :: tail => - val tailP = loop(tail) - Pattern.PositionalStruct( - (PackageName.PredefName, Constructor("TupleCons")), - h :: tailP :: Nil) - } - - loop(args) - case (Pattern.StructKind.Named(nm, _), args) => - Pattern.PositionalStruct(nameToCons(nm), args) - case (Pattern.StructKind.NamedPartial(nm, _), args) => - Pattern.PositionalStruct(nameToCons(nm), args) - } - .mapType { tref => - TypeRefConverter[cats.Id](tref) { tpe => - Type.Const.Defined(PackageName.PredefName, TypeName(tpe)) + val nameToCons: Constructor => (PackageName, Constructor) = { cons => + (PackageName.PredefName, cons) + } + + /** This is sufficient for these tests, but is not a full features pattern + * compiler. + */ + def parsedToExpr( + pat: Pattern.Parsed + ): Pattern[(PackageName, Constructor), rankn.Type] = + pat + .mapStruct[(PackageName, Constructor)] { + case (Pattern.StructKind.Tuple, args) => + // this is a tuple pattern + def loop( + args: List[Pattern[(PackageName, Constructor), TypeRef]] + ): Pattern[(PackageName, Constructor), TypeRef] = + args match { + case Nil => + // () + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("Unit")), + Nil + ) + case h :: tail => + val tailP = loop(tail) + Pattern.PositionalStruct( + (PackageName.PredefName, Constructor("TupleCons")), + h :: tailP :: Nil + ) + } + + loop(args) + case (Pattern.StructKind.Named(nm, _), args) => + Pattern.PositionalStruct(nameToCons(nm), args) + case (Pattern.StructKind.NamedPartial(nm, _), args) => + Pattern.PositionalStruct(nameToCons(nm), args) + } + .mapType { tref => + TypeRefConverter[cats.Id](tref) { tpe => + Type.Const.Defined(PackageName.PredefName, TypeName(tpe)) + } } - } - Parser.unsafeParse(Pattern.matchParser.listSyntax, str) + Parser + .unsafeParse(Pattern.matchParser.listSyntax, str) .map(parsedToExpr _) } - def notTotal(te: TypeEnv[Any], pats: List[Pattern[(PackageName, Constructor), Type]], testMissing: Boolean = true): Unit = { + def notTotal( + te: TypeEnv[Any], + pats: List[Pattern[(PackageName, Constructor), Type]], + testMissing: Boolean = true + ): Unit = { val res = TotalityCheck(te).isTotal(pats) assert(!res, pats.toString) @@ -143,7 +162,11 @@ enum Bool: False, True } } - def testTotality(te: TypeEnv[Any], pats: List[Pattern[(PackageName, Constructor), Type]], tight: Boolean = false) = { + def testTotality( + te: TypeEnv[Any], + pats: List[Pattern[(PackageName, Constructor), Type]], + tight: Boolean = false + ) = { val res = TotalityCheck(te).missingBranches(pats) val asStr = res.map(showPat) assert(asStr == Nil, showPats(pats)) @@ -151,7 +174,7 @@ enum Bool: False, True // any missing pattern shouldn't be total: def allButOne[A](head: A, tail: List[A]): List[List[A]] = tail match { - case Nil => Nil + case Nil => Nil case h :: rest => // we can either delete the head or one from the tail: val keepHead = allButOne(h, rest).map(head :: _) @@ -160,7 +183,9 @@ enum Bool: False, True pats match { case h :: tail if tight => - allButOne(h, tail).foreach(notTotal(te, _, testMissing = false)) // don't make an infinite loop here + allButOne(h, tail).foreach( + notTotal(te, _, testMissing = false) + ) // don't make an infinite loop here case _ => () } } @@ -178,7 +203,6 @@ struct Unit val pats = patterns("[Unit]") testTotality(te, pats) - val te1 = typeEnvOf("""# struct TupleCons(a, b) """) @@ -195,7 +219,11 @@ enum Option: None, Some(get) testTotality(te, patterns("[Some(_) | None]"), tight = true) testTotality(te, patterns("[Some(_), _]")) testTotality(te, patterns("[Some(1), Some(x), None]")) - testTotality(te, patterns("[Some(Some(_)), Some(None), None]"), tight = true) + testTotality( + te, + patterns("[Some(Some(_)), Some(None), None]"), + tight = true + ) testTotality(te, patterns("[Some(Some(_) | None), None]"), tight = true) notTotal(te, patterns("[Some(_)]")) @@ -210,13 +238,19 @@ enum Option: None, Some(get) enum Either: Left(l), Right(r) """) testTotality(te, patterns("[Left(_), Right(_)]")) - testTotality(te, - patterns("[Left(Right(_)), Left(Left(_)), Right(Left(_)), Right(Right(_))]"), - tight = true) - - testTotality(te, + testTotality( + te, + patterns( + "[Left(Right(_)), Left(Left(_)), Right(Left(_)), Right(Right(_))]" + ), + tight = true + ) + + testTotality( + te, patterns("[Left(Right(_) | Left(_)), Right(Left(_) | Right(_))]"), - tight = true) + tight = true + ) notTotal(te, patterns("[Left(_)]")) notTotal(te, patterns("[Right(_)]")) @@ -226,10 +260,22 @@ enum Either: Left(l), Right(r) test("test List matching") { testTotality(predefTE, patterns("[[], [h, *tail]]"), tight = true) - testTotality(predefTE, patterns("[[], [h, *tail], [h0, h1, *tail]]"), tight = true) + testTotality( + predefTE, + patterns("[[], [h, *tail], [h0, h1, *tail]]"), + tight = true + ) testTotality(predefTE, patterns("[[], [*tail, _]]"), tight = true) - testTotality(predefTE, patterns("[[*_, True, *_], [], [False, *_]]"), tight = true) - testTotality(predefTE, patterns("[[*_, True, *_], [] | [False, *_]]"), tight = true) + testTotality( + predefTE, + patterns("[[*_, True, *_], [], [False, *_]]"), + tight = true + ) + testTotality( + predefTE, + patterns("[[*_, True, *_], [] | [False, *_]]"), + tight = true + ) notTotal(predefTE, patterns("[[], [h, *tail, _]]")) } @@ -241,33 +287,56 @@ enum Option: None, Some(get) struct TupleCons(fst, snd) """) - testTotality(te, patterns("[None, Some(Left(_)), Some(Right(_))]"), tight = true) + testTotality( + te, + patterns("[None, Some(Left(_)), Some(Right(_))]"), + tight = true + ) testTotality(te, patterns("[None, Some(Left(_) | Right(_))]"), tight = true) - testTotality(te, patterns("[None, Some(TupleCons(Left(_), _)), Some(TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]"), tight = true) - testTotality(te, patterns("[None, Some(TupleCons(Left(_), _) | TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]"), tight = true) + testTotality( + te, + patterns( + "[None, Some(TupleCons(Left(_), _)), Some(TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]" + ), + tight = true + ) + testTotality( + te, + patterns( + "[None, Some(TupleCons(Left(_), _) | TupleCons(_, Right(_))), Some(TupleCons(Right(_), Left(_)))]" + ), + tight = true + ) } test("compose List with structs") { val te = typeEnvOf("""# enum Either: Left(l), Right(r) """) - testTotality(te, patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), tight = true) - testTotality(te, patterns("[Left([]), Left([h, *_]), Right([]), Right([h, *_])]"), tight = true) + testTotality( + te, + patterns("[[Left(_), *_], [Right(_), *_], [], [_, _, *_]]"), + tight = true + ) + testTotality( + te, + patterns("[Left([]), Left([h, *_]), Right([]), Right([h, *_])]"), + tight = true + ) } - test("test intersection") { val p0 :: p1 :: p1norm :: Nil = patterns("[[*_], [*_, _], [_, *_]]") - TotalityCheck(predefTE).intersection(p0, p1) match { - case List(intr) => assert(intr == p1norm) - case other => fail(s"expected exactly one intersection: $other") - } + TotalityCheck(predefTE).intersection(p0, p1) match { + case List(intr) => assert(intr == p1norm) + case other => fail(s"expected exactly one intersection: $other") + } val p2 :: p3 :: Nil = patterns("[[*_], [_, _]]") - TotalityCheck(predefTE).intersection(p2, p3) match { - case List(intr) => assert(p3 == intr) - case other => fail(s"expected exactly one intersection: $other") - } + TotalityCheck(predefTE).intersection(p2, p3) match { + case List(intr) => assert(p3 == intr) + case other => fail(s"expected exactly one intersection: $other") + } } test("test some difference examples") { @@ -297,7 +366,7 @@ enum Either: Left(l), Right(r) val p0 :: p1 :: Nil = patterns("[[*_, _], [_, *_]]") TotalityCheck(predefTE).intersection(p0, p1) match { case List(res) if res == p0 || res == p1 => succeed - case Nil => fail("these do overlap") + case Nil => fail("these do overlap") case nonUnified => fail(s"didn't unify to one: $nonUnified") } } @@ -310,13 +379,52 @@ enum Either: Left(l), Right(r) import Identifier.Name val regressions: List[(Pat, Pat, Pat)] = - (Named(Name("hTt"), StrPat(NonEmptyList.of(NamedStr(Name("rfb")), LitStr("q"), NamedStr(Name("ngkrx"))))), + ( + Named( + Name("hTt"), + StrPat( + NonEmptyList + .of(NamedStr(Name("rfb")), LitStr("q"), NamedStr(Name("ngkrx"))) + ) + ), WildCard, - Named(Name("hjbmtklh"),StrPat(NonEmptyList.of(NamedStr(Name("qz8lcT")), WildStr, LitStr("p7"), NamedStr(Name("hqxprG")))))) :: - (WildCard, - ListPat(List(NamedList(Name("nv6")), Item(Literal(Lit.fromInt(-17))), Item(WildCard))), - ListPat(List(Item(StrPat(NonEmptyList.of(WildStr))), Item(StrPat(NonEmptyList.of(NamedStr(Name("eejhh")), LitStr("jbuzfcwsumP"), WildStr)))))) :: - Nil + Named( + Name("hjbmtklh"), + StrPat( + NonEmptyList.of( + NamedStr(Name("qz8lcT")), + WildStr, + LitStr("p7"), + NamedStr(Name("hqxprG")) + ) + ) + ) + ) :: + ( + WildCard, + ListPat( + List( + NamedList(Name("nv6")), + Item(Literal(Lit.fromInt(-17))), + Item(WildCard) + ) + ), + ListPat( + List( + Item(StrPat(NonEmptyList.of(WildStr))), + Item( + StrPat( + NonEmptyList.of( + NamedStr(Name("eejhh")), + LitStr("jbuzfcwsumP"), + WildStr + ) + ) + ) + ) + ) + ) :: + Nil regressions.foreach { case (a, b, c) => diffIntersectionLaw(a, b, c) @@ -330,23 +438,27 @@ enum Either: Left(l), Right(r) } val regressions: List[(Pat, Pat)] = - List( - { - val struct = Pattern.PositionalStruct((PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), Nil) - val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) - (struct, lst) - }) + List({ + val struct = Pattern.PositionalStruct( + (PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), + Nil + ) + val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) + (struct, lst) + }) regressions.foreach { case (a, b) => law(a, b) } } test("subset consistency regressions") { - val regressions: List[(Pat, Pat)] = - { - val struct = Pattern.PositionalStruct((PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), Nil) - val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) - (struct, lst) - } :: + val regressions: List[(Pat, Pat)] = { + val struct = Pattern.PositionalStruct( + (PackageName(NonEmptyList.of("Pack")), Identifier.Constructor("Foo")), + Nil + ) + val lst = Pattern.ListPat(List(Pattern.ListPart.WildList)) + (struct, lst) + } :: Nil regressions.foreach { case (a, b) => @@ -361,14 +473,23 @@ enum Either: Left(l), Right(r) import StrPart.WildStr val regressions: List[(Pat, Pat)] = - List( - { - val left = ListPat(List(Item(WildCard), WildList)) - val right = ListPat(List(Item(Var(Name("bey6ct"))), Item(Literal(Lit.fromInt(42))), Item(StrPat(NonEmptyList.of(WildStr))), Item(Literal(Lit("agfn"))), Item(WildCard))) - (left, right) - }) + List({ + val left = ListPat(List(Item(WildCard), WildList)) + val right = ListPat( + List( + Item(Var(Name("bey6ct"))), + Item(Literal(Lit.fromInt(42))), + Item(StrPat(NonEmptyList.of(WildStr))), + Item(Literal(Lit("agfn"))), + Item(WildCard) + ) + ) + (left, right) + }) - regressions.foreach { case (a, b) => differenceIsIdempotent(a, b, eqPatterns) } + regressions.foreach { case (a, b) => + differenceIsIdempotent(a, b, eqPatterns) + } } test("if a n b = 0 then a - b = a regressions") { @@ -382,16 +503,27 @@ enum Either: Left(l), Right(r) List( { val left = ListPat(List(Item(WildCard), WildList)) - val right = ListPat(List(Item(Var(Name("bey6ct"))), Item(Literal(Lit.fromInt(42))), Item(StrPat(NonEmptyList.of(WildStr))), Item(Literal(Lit("agfn"))), Item(WildCard))) + val right = ListPat( + List( + Item(Var(Name("bey6ct"))), + Item(Literal(Lit.fromInt(42))), + Item(StrPat(NonEmptyList.of(WildStr))), + Item(Literal(Lit("agfn"))), + Item(WildCard) + ) + ) (left, right) - }, - { - val left = ListPat(List(NamedList(Name("a")), Item(WildCard), Item(Var(Name("b"))))) + }, { + val left = ListPat( + List(NamedList(Name("a")), Item(WildCard), Item(Var(Name("b")))) + ) val right = ListPat(List()) (left, right) } ) - regressions.foreach { case (a, b) => emptyIntersectionMeansDiffIdent(a, b, eqPatterns) } + regressions.foreach { case (a, b) => + emptyIntersectionMeansDiffIdent(a, b, eqPatterns) + } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala index c96ea279d..acca26418 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypeRefTest.scala @@ -1,12 +1,15 @@ package org.bykn.bosatsu -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.bykn.bosatsu.rankn.Type import org.scalatest.funsuite.AnyFunSuite class TypeRefTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 500000) + // PropertyCheckConfiguration(minSuccessful = 500000) PropertyCheckConfiguration(minSuccessful = 5000) import Generators.{typeRefGen, shrinkTypeRef} @@ -21,14 +24,18 @@ class TypeRefTest extends AnyFunSuite { val pn = PackageName.parts("Test") forAll(typeRefGen) { tr => - val tpe = TypeRefConverter[cats.Id](tr) { c => Type.Const.Defined(pn, TypeName(c)) } - val tr1 = TypeRefConverter.fromTypeA[Option](tpe, + val tpe = TypeRefConverter[cats.Id](tr) { c => + Type.Const.Defined(pn, TypeName(c)) + } + val tr1 = TypeRefConverter.fromTypeA[Option]( + tpe, { _ => None }, { _ => None }, { case Type.Const.Defined(p, t) if p == pn => Some(TypeRef.TypeName(t)) - case _ => None - }) + case _ => None + } + ) assert(tr1 == Some(tr.normalizeForAll), s"tpe = $tpe") } diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index 164b0bade..705959d40 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -4,7 +4,10 @@ import cats.data.{NonEmptyList, State, Writer} import cats.implicits._ import org.scalacheck.{Arbitrary, Gen} import org.scalatest.funsuite.AnyFunSuite -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import scala.collection.immutable.SortedSet import Arbitrary.arbitrary @@ -15,21 +18,21 @@ import rankn.{Type, NTypeGen} class TypedExprTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) def allVars[A](te: TypedExpr[A]): Set[Bindable] = { type W[B] = Writer[Set[Bindable], B] te.traverseUp[W] { - case v@TypedExpr.Local(ident, _, _) => Writer(Set(ident), v) - case notVar => Writer(Set.empty, notVar) - }.run._1 + case v @ TypedExpr.Local(ident, _, _) => Writer(Set(ident), v) + case notVar => Writer(Set.empty, notVar) + }.run + ._1 } - /** - * Assert two bits of code normalize to the same thing - */ + /** Assert two bits of code normalize to the same thing + */ def normSame(s1: String, s2: String) = checkLast(s1) { t1 => checkLast(s2) { t2 => @@ -42,7 +45,10 @@ class TypedExprTest extends AnyFunSuite { val frees = TypedExpr.freeVarsSet(te :: Nil).toSet val av = allVars(te) val missing = frees -- av - assert(missing.isEmpty, s"expression:\n\n${te.repr}\n\nallVars: $av\n\nfrees: $frees") + assert( + missing.isEmpty, + s"expression:\n\n${te.repr}\n\nallVars: $av\n\nfrees: $frees" + ) } forAll(genTypedExpr)(law _) @@ -97,16 +103,19 @@ y = match x: case notLit => fail(s"expected Literal got: ${notLit.repr}") } - normSame("""# + normSame( + """# struct Tup2(a, b) x = 23 x = Tup2(1, 2) y = match x: case Tup2(a, _): a -""", """# +""", + """# y = 1 -""") +""" + ) checkLast("""# struct Tup2(a, b) @@ -234,7 +243,8 @@ y = match x: } test("we can lift a match above a lambda") { - normSame("""# + normSame( + """# struct Tup2(a, b) y = Tup2(1, 2) @@ -242,12 +252,15 @@ y = Tup2(1, 2) def inner_match(x): match y: case Tup2(a, _): Tup2(a, x) -""", """# +""", + """# struct Tup2(a, b) inner_match = x -> Tup2(1, x) -""") +""" + ) - normSame("""# + normSame( + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -258,7 +271,8 @@ def run(y): case R(b): Tup2(x, b) inner_match -""", """# +""", + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -266,11 +280,13 @@ def run(y): match y: case L(a): x -> Tup2(a, x) case R(b): x -> Tup2(x, b) -""") +""" + ) } test("we can push lets into match") { - normSame("""# + normSame( + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -279,7 +295,8 @@ def run(y, x): match x: L(_): z R(r): r -""", """# +""", + """# struct Tup2(a, b) enum Eith: L(left), R(right) @@ -287,35 +304,45 @@ def run(y, x): match x: L(_): y R(r): r -""") +""" + ) } test("we can evaluate constant matches") { - normSame("""# + normSame( + """# x = match 1: case (1 | 2) as x: x case _: -1 -""", """# +""", + """# x = 1 -""") +""" + ) - normSame("""# + normSame( + """# x = match 1: case _: -1 -""", """# +""", + """# x = -1 -""") +""" + ) - normSame("""# + normSame( + """# y = 21 def foo(_): match y: case 42: 0 case x: x -""", """# +""", + """# foo = _ -> 21 -""") +""" + ) /* * This does not yet work @@ -333,7 +360,7 @@ def foo(_): """, """# foo = _ -> 1 """) - */ + */ } val intTpe = Type.IntType @@ -347,53 +374,77 @@ foo = _ -> 1 def varTE(n: String, tpe: Type): TypedExpr[Unit] = TypedExpr.Local(Identifier.Name(n), tpe, ()) - def let(n: String, ex1: TypedExpr[Unit], ex2: TypedExpr[Unit]): TypedExpr[Unit] = + def let( + n: String, + ex1: TypedExpr[Unit], + ex2: TypedExpr[Unit] + ): TypedExpr[Unit] = TypedExpr.Let(Identifier.Name(n), ex1, ex2, RecursionKind.NonRecursive, ()) - def letrec(n: String, ex1: TypedExpr[Unit], ex2: TypedExpr[Unit]): TypedExpr[Unit] = + def letrec( + n: String, + ex1: TypedExpr[Unit], + ex2: TypedExpr[Unit] + ): TypedExpr[Unit] = TypedExpr.Let(Identifier.Name(n), ex1, ex2, RecursionKind.Recursive, ()) - def app(fn: TypedExpr[Unit], arg: TypedExpr[Unit], tpe: Type): TypedExpr[Unit] = + def app( + fn: TypedExpr[Unit], + arg: TypedExpr[Unit], + tpe: Type + ): TypedExpr[Unit] = TypedExpr.App(fn, NonEmptyList.one(arg), tpe, ()) - def lam(n: String, nt: Type, res: TypedExpr[Unit]): TypedExpr[Unit] = - TypedExpr.AnnotatedLambda(NonEmptyList.one((Identifier.Name(n), nt)), res, ()) + def lam(n: String, nt: Type, res: TypedExpr[Unit]): TypedExpr[Unit] = + TypedExpr.AnnotatedLambda( + NonEmptyList.one((Identifier.Name(n), nt)), + res, + () + ) test("test let substitution") { { // substitution in let val let1 = let("y", varTE("x", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("x"), int(2), let1) == - Some(let("y", int(2), varTE("y", intTpe)))) + assert( + TypedExpr.substitute(Identifier.Name("x"), int(2), let1) == + Some(let("y", int(2), varTE("y", intTpe))) + ) } { // substitution in let with a masking val let1 = let("y", varTE("x", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("x"), varTE("y", intTpe), let1) == - None) + assert( + TypedExpr.substitute(Identifier.Name("x"), varTE("y", intTpe), let1) == + None + ) } { // substitution in let with a shadowing in result val let1 = let("y", varTE("y", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == - Some(let("y", int(42), varTE("y", intTpe)))) + assert( + TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == + Some(let("y", int(42), varTE("y", intTpe))) + ) } { // substitution in letrec with a shadowing in bind and result val let1 = letrec("y", varTE("y", intTpe), varTE("y", intTpe)) - assert(TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == - Some(let1)) + assert( + TypedExpr.substitute(Identifier.Name("y"), int(42), let1) == + Some(let1) + ) } } lazy val genNonFree: Gen[TypedExpr[Unit]] = - genTypedExpr.flatMap { te => - if (TypedExpr.freeVars(te :: Nil).isEmpty) Gen.const(te) - else genNonFree - } + genTypedExpr.flatMap { te => + if (TypedExpr.freeVars(te :: Nil).isEmpty) Gen.const(te) + else genNonFree + } test("after substitution, a variable is no longer free") { forAll(genTypedExpr, genNonFree) { (te0, te1) => @@ -418,7 +469,7 @@ foo = _ -> 1 lazy val nf: Gen[Bindable] = Generators.bindIdentGen.flatMap { case isfree if frees(isfree) => nf - case notfree => Gen.const(notfree) + case notfree => Gen.const(notfree) } nf @@ -430,28 +481,40 @@ foo = _ -> 1 } yield (nf, te) forAll(pair, genNonFree) { case ((b, te0), te1) => - TypedExpr.substitute(b, te1, te0) match { - case None => - // te1 has no free variables, this shouldn't fail - assert(false) + TypedExpr.substitute(b, te1, te0) match { + case None => + // te1 has no free variables, this shouldn't fail + assert(false) - case Some(te0sub) => assert(te0sub == te0) - } + case Some(te0sub) => assert(te0sub == te0) + } } } - test("let x = y in x == y") { // inline lets of vars - assert(TypedExprNormalization.normalize(let("x", varTE("y", intTpe), varTE("x", intTpe))) == - Some(varTE("y", intTpe))) + assert( + TypedExprNormalization.normalize( + let("x", varTE("y", intTpe), varTE("x", intTpe)) + ) == + Some(varTE("y", intTpe)) + ) } val normalLet = - let("x", varTE("y", intTpe), - let("y", app(varTE("z", intTpe), int(43), intTpe), - app(app(varTE("x", intTpe), varTE("y", intTpe), intTpe), - varTE("y", intTpe), intTpe))) + let( + "x", + varTE("y", intTpe), + let( + "y", + app(varTE("z", intTpe), int(43), intTpe), + app( + app(varTE("x", intTpe), varTE("y", intTpe), intTpe), + varTE("y", intTpe), + intTpe + ) + ) + ) test("we can't inline using a shadow: let x = y in let y = z in x(y, y)") { // we can't inline a shadow @@ -462,20 +525,37 @@ foo = _ -> 1 } test("if w doesn't have x free: (app (let x y z) w) == let x y (app z w)") { - assert(TypedExprNormalization.normalize(app(normalLet, varTE("w", intTpe), intTpe)) == - Some( - let("x", varTE("y", intTpe), - let("y", app(varTE("z", intTpe), int(43), intTpe), - app(app(app(varTE("x", intTpe), varTE("y", intTpe), intTpe), - varTE("y", intTpe), intTpe), - varTE("w", intTpe), intTpe))))) + assert( + TypedExprNormalization.normalize( + app(normalLet, varTE("w", intTpe), intTpe) + ) == + Some( + let( + "x", + varTE("y", intTpe), + let( + "y", + app(varTE("z", intTpe), int(43), intTpe), + app( + app( + app(varTE("x", intTpe), varTE("y", intTpe), intTpe), + varTE("y", intTpe), + intTpe + ), + varTE("w", intTpe), + intTpe + ) + ) + ) + ) + ) } test("x -> f(x) == f") { val f = varTE("f", Type.Fun(intTpe, intTpe)) val left = lam("x", intTpe, app(f, varTE("x", intTpe), intTpe)) - + assert(TypedExprNormalization.normalize(left) == Some(f)) checkLast(""" @@ -506,7 +586,8 @@ x = Foo val int2int = Type.Fun(intTpe, intTpe) val f = varTE("f", Type.Fun(intTpe, int2int)) val z = varTE("z", intTpe) - val lamf = lam("x", intTpe, app(app(f, varTE("x", intTpe), int2int), z, intTpe)) + val lamf = + lam("x", intTpe, app(app(f, varTE("x", intTpe), int2int), z, intTpe)) val y = varTE("y", intTpe) val left = app(lamf, y, intTpe) val right = app(app(f, y, int2int), z, intTpe) @@ -519,7 +600,6 @@ f = (_, y) -> y z = 1 res = y -> (x -> f(x, z))(y) """) { te1 => - checkLast(""" f = (_, y) -> y res = y -> f(y, 1) @@ -543,7 +623,7 @@ fn = ( ) ) """) { te1 => - checkLast(""" + checkLast(""" enum FooBar: Foo, Bar fn = (x: FooBar) -> x @@ -563,7 +643,7 @@ x = ( c ) """) { te1 => - checkLast(""" + checkLast(""" enum FooBar: Foo, Bar x = Foo @@ -587,7 +667,9 @@ x = Foo test("TypedExpr.substituteTypeVar of identity is identity") { forAll(genTypedExpr, Gen.listOf(NTypeGen.genBound)) { (te, bounds) => - val identMap: Map[Type.Var, Type] = bounds.map { b => (b, Type.TyVar(b)) }.toMap + val identMap: Map[Type.Var, Type] = bounds.map { b => + (b, Type.TyVar(b)) + }.toMap assert(TypedExpr.substituteTypeVar(te, identMap) == te) } } @@ -596,9 +678,12 @@ x = Foo forAll(genTypedExpr, Gen.listOf(NTypeGen.genBound)) { (te, bounds) => val tpes = te.allTypes val avoid = tpes.toSet | bounds.map(Type.TyVar(_)).toSet - val replacements = Type.allBinders.iterator.filterNot { t => avoid(Type.TyVar(t)) } + val replacements = Type.allBinders.iterator.filterNot { t => + avoid(Type.TyVar(t)) + } val identMap: Map[Type.Var, Type] = - bounds.iterator.zip(replacements) + bounds.iterator + .zip(replacements) .map { case (b, v) => (b, Type.TyVar(v)) } .toMap val te1 = TypedExpr.substituteTypeVar(te, identMap) @@ -610,25 +695,33 @@ x = Foo test("TypedExpr.substituteTypeVar is not an identity function") { // if we replace all the current types with some bound types, things won't be the same forAll(genTypedExpr) { te => - val tpes: Set[Type.Var] = te.allTypes.iterator.collect { case Type.TyVar(b) => b }.toSet + val tpes: Set[Type.Var] = te.allTypes.iterator.collect { + case Type.TyVar(b) => b + }.toSet implicit def setM[A: Ordering]: cats.Monoid[SortedSet[A]] = new cats.Monoid[SortedSet[A]] { def empty = SortedSet.empty def combine(a: SortedSet[A], b: SortedSet[A]) = a ++ b - } + } // All the vars that are used in bounds - val bounds: Set[Type.Var] = te.traverseType { (t: Type) => - t match { - case Type.ForAll(ps, _) => Writer(SortedSet[Type.Var](ps.toList.map(_._1): _*), t) - case _ => Writer(SortedSet[Type.Var](), t) + val bounds: Set[Type.Var] = te + .traverseType { (t: Type) => + t match { + case Type.ForAll(ps, _) => + Writer(SortedSet[Type.Var](ps.toList.map(_._1): _*), t) + case _ => Writer(SortedSet[Type.Var](), t) + } } - }.run._1.toSet[Type.Var] + .run + ._1 + .toSet[Type.Var] val replacements = Type.allBinders.iterator.filterNot(tpes) val identMap: Map[Type.Var, Type] = - tpes.filterNot(bounds) + tpes + .filterNot(bounds) .iterator .zip(replacements) .map { case (b, v) => (b, Type.TyVar(v)) } @@ -647,7 +740,9 @@ x = Foo } } - def count[A](te: TypedExpr[A])(fn: PartialFunction[TypedExpr[A], Boolean]): Int = { + def count[A]( + te: TypedExpr[A] + )(fn: PartialFunction[TypedExpr[A], Boolean]): Int = { type W[B] = Writer[Int, B] val (count, _) = te.traverseUp[W] { inner => @@ -658,24 +753,25 @@ x = Foo count } - def countMatch[A](te: TypedExpr[A]) = count(te) { case TypedExpr.Match(_, _, _) => true } - def countLet[A](te: TypedExpr[A]) = count(te) { case TypedExpr.Let(_, _, _, _, _) => true } + def countMatch[A](te: TypedExpr[A]) = count(te) { + case TypedExpr.Match(_, _, _) => true + } + def countLet[A](te: TypedExpr[A]) = count(te) { + case TypedExpr.Let(_, _, _, _, _) => true + } test("test match removed from some examples") { - checkLast( - """ + checkLast(""" x = _ -> 1 """) { te => assert(countMatch(te) == 0) } - checkLast( - """ + checkLast(""" x = 10 y = match x: case z: z """) { te => assert(countMatch(te) == 0) } - checkLast( - """ + checkLast(""" x = 10 y = match x: case _: 20 @@ -684,15 +780,13 @@ y = match x: test("test let removed from some examples") { // this should turn into `y = 20` as the last expression - checkLast( - """ + checkLast(""" x = 10 y = match x: case _: 20 """) { te => assert(countLet(te) == 0) } - checkLast( - """ + checkLast(""" foo = ( x = 1 _ = x @@ -723,7 +817,9 @@ foo = ( } def lawR[A, B](te: TypedExpr[B], a: A)(fn: (B, A) => A) = { - val viaFold = te.foldRight(cats.Eval.now(a)) { (b, r) => r.map { j => fn(b, j) } }.value + val viaFold = te + .foldRight(cats.Eval.now(a)) { (b, r) => r.map { j => fn(b, j) } } + .value val viaTraverse: State[A, Unit] = te.traverse_[State[A, *], Unit] { b => for { i <- State.get[A] @@ -735,7 +831,6 @@ foo = ( assert(viaFold == viaTraverse.runS(a).value, s"${te.repr}") } - forAll(genTypedExprInt, Gen.choose(0, 1000)) { (te, init) => // make a commutative int function law(init, te) { (a, b) => (a + 1) * b } @@ -750,9 +845,11 @@ foo = ( def law[A, B: Monoid](te: TypedExpr[A])(fn: A => B) = { val viaFold = te.foldMap(fn) - val viaTraverse: Const[B, Unit] = te.traverse[Const[B, *], Unit] { b => - Const[B, Unit](fn(b)) - }.void + val viaTraverse: Const[B, Unit] = te + .traverse[Const[B, *], Unit] { b => + Const[B, Unit](fn(b)) + } + .void assert(viaFold == viaTraverse.getConst, s"${te.repr}") } @@ -761,11 +858,12 @@ foo = ( // non-commutative forAll(genTypedExprChar, arbitrary[Char => String])(law(_)(_)) - val lamconst: TypedExpr[String] = + val lamconst: TypedExpr[String] = TypedExpr.AnnotatedLambda( NonEmptyList.one((Identifier.Name("x"), intTpe)), int(1).as("a"), - "b") + "b" + ) assert(lamconst.foldMap(identity) == "ab") assert(lamconst.traverse { a => Const[String, Unit](a) }.getConst == "ab") @@ -774,15 +872,17 @@ foo = ( test("TypedExpr.traverse.void matches traverse_") { import cats.data.Const forAll(genTypedExprInt, arbitrary[Int => String]) { (te, fn) => - assert(te.traverse { i => Const[String, Unit](fn(i)) }.void == - te.traverse_ { i => Const[String, Unit](fn(i)) }) + assert( + te.traverse { i => Const[String, Unit](fn(i)) }.void == + te.traverse_ { i => Const[String, Unit](fn(i)) } + ) } } test("TypedExpr.foldRight matches foldRight for commutative funs") { forAll(genTypedExprInt, Gen.choose(0, 1000)) { (te, init) => - - val right = te.foldRight(cats.Eval.now(init)) { (i, ej) => ej.map(_ + i) }.value + val right = + te.foldRight(cats.Eval.now(init)) { (i, ej) => ej.map(_ + i) }.value val left = te.foldLeft(init)(_ + _) assert(right == left) } @@ -790,8 +890,11 @@ foo = ( test("TypedExpr.foldRight matches foldRight for non-commutative funs") { forAll(genTypedExprInt) { te => - - val right = te.foldRight(cats.Eval.now("")) { (i, ej) => ej.map { j => i.toString + j } }.value + val right = te + .foldRight(cats.Eval.now("")) { (i, ej) => + ej.map { j => i.toString + j } + } + .value val left = te.foldLeft("") { (i, j) => i + j.toString } assert(right == left) } diff --git a/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala b/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala index b022037d7..b3898c3e3 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ValueTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import org.scalacheck.{Arbitrary, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import Value._ import org.scalatest.funsuite.AnyFunSuite @@ -9,7 +12,7 @@ class ValueTest extends AnyFunSuite { import GenValue.genValue implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) test("SumValue.toString is what we expect") { @@ -29,7 +32,7 @@ class ValueTest extends AnyFunSuite { forAll(genValue) { v => VOption.some(v) match { case VOption(Some(v1)) => assert(v1 == v) - case other => fail(s"expected Some($v) got $other") + case other => fail(s"expected Some($v) got $other") } } @@ -50,7 +53,7 @@ class ValueTest extends AnyFunSuite { forAll(Gen.listOf(genValue)) { vs => VList(vs) match { case VList(vs1) => assert(vs1 == vs) - case other => fail(s"expected VList($vs) got $other") + case other => fail(s"expected VList($vs) got $other") } } diff --git a/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala b/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala index a9d25ee70..4fd7ab235 100644 --- a/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/ValueToDocTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import rankn.{NTypeGen, Type, TypeEnv} import TestUtils.typeEnvOf @@ -10,7 +13,9 @@ import org.scalatest.funsuite.AnyFunSuite class ValueToDocTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = if (Platform.isScalaJvm) 1000 else 20) + PropertyCheckConfiguration(minSuccessful = + if (Platform.isScalaJvm) 1000 else 20 + ) test("never throw when converting to doc") { val tegen = Generators.typeEnvGen(PackageName.parts("Foo"), Gen.const(())) @@ -19,12 +24,14 @@ class ValueToDocTest extends AnyFunSuite { tegen.flatMap { te => val tyconsts = te.allDefinedTypes.map(_.toTypeConst) - val theseTypes = NTypeGen.genDepth(4, if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts))) + val theseTypes = NTypeGen.genDepth( + 4, + if (tyconsts.isEmpty) None else Some(Gen.oneOf(tyconsts)) + ) theseTypes.map((te, _)) } - forAll(withType, GenValue.genValue) { case ((te, t), v) => val vd = ValueToDoc(te.toDefinedType(_)) vd.toDoc(t)(v) @@ -33,7 +40,9 @@ class ValueToDocTest extends AnyFunSuite { } test("some hand written cases round trip") { - val te = typeEnvOf(PackageName.parts("Test"), """ + val te = typeEnvOf( + PackageName.parts("Test"), + """ struct MyUnit # wrappers are removed @@ -43,19 +52,21 @@ struct MyPair(fst, snd) enum MyEither: L(left), R(right) enum MyNat: Z, S(prev: MyNat) -""") +""" + ) val conv = ValueToDoc(te.toDefinedType(_)) def stringToType(t: String): Type = { val tr = Parser.unsafeParse(TypeRef.parser, t) TypeRefConverter[cats.Id](tr) { cons => - te.referencedPackages.toList.flatMap { pack => - val const = Type.Const.Defined(pack, TypeName(cons)) - te.toDefinedType(const).map(_ => const) - } - .headOption - .getOrElse(Type.Const.predef(cons.asString)) + te.referencedPackages.toList + .flatMap { pack => + val const = Type.Const.Defined(pack, TypeName(cons)) + te.toDefinedType(const).map(_ => const) + } + .headOption + .getOrElse(Type.Const.predef(cons.asString)) } } @@ -65,7 +76,7 @@ enum MyNat: Z, S(prev: MyNat) toDoc(v) match { case Right(doc) => assert(doc.render(80) == str) - case Left(err) => fail(s"could not handle to Value: $tpe, $v, $err") + case Left(err) => fail(s"could not handle to Value: $tpe, $v, $err") } } @@ -73,10 +84,25 @@ enum MyNat: Z, S(prev: MyNat) law("String", Value.Str("hello world"), "'hello world'") law("MyUnit", Value.UnitValue, "MyUnit") law("MyWrapper[MyUnit]", Value.UnitValue, "MyWrapper { item: MyUnit }") - law("MyWrapper[MyWrapper[MyUnit]]", Value.UnitValue, "MyWrapper { item: MyWrapper { item: MyUnit } }") - law("MyPair[MyUnit, MyUnit]", Value.ProductValue.fromList(List(Value.UnitValue, Value.UnitValue)), - "MyPair { fst: MyUnit, snd: MyUnit }") - law("MyEither[MyUnit, MyUnit]", Value.SumValue(0, Value.ProductValue.fromList(List(Value.UnitValue))), "L { left: MyUnit }") - law("MyEither[MyUnit, MyUnit]", Value.SumValue(1, Value.ProductValue.fromList(List(Value.UnitValue))), "R { right: MyUnit }") + law( + "MyWrapper[MyWrapper[MyUnit]]", + Value.UnitValue, + "MyWrapper { item: MyWrapper { item: MyUnit } }" + ) + law( + "MyPair[MyUnit, MyUnit]", + Value.ProductValue.fromList(List(Value.UnitValue, Value.UnitValue)), + "MyPair { fst: MyUnit, snd: MyUnit }" + ) + law( + "MyEither[MyUnit, MyUnit]", + Value.SumValue(0, Value.ProductValue.fromList(List(Value.UnitValue))), + "L { left: MyUnit }" + ) + law( + "MyEither[MyUnit, MyUnit]", + Value.SumValue(1, Value.ProductValue.fromList(List(Value.UnitValue))), + "R { right: MyUnit }" + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala b/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala index df8529dc1..34afddbb8 100644 --- a/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/VarianceTest.scala @@ -9,7 +9,8 @@ object VarianceGen { Variance.Phantom, Variance.Contravariant, Variance.Covariant, - Variance.Invariant) + Variance.Invariant + ) implicit val arbVar: Arbitrary[Variance] = Arbitrary(gen) } @@ -27,7 +28,9 @@ class VarianceTest extends AnyFunSuite { test("variance combine is associative") { forAll { (v1: Variance, v2: Variance, v3: Variance) => - assert(V.combine(v1, V.combine(v2, v3)) == V.combine(V.combine(v1, v2), v3)) + assert( + V.combine(v1, V.combine(v2, v3)) == V.combine(V.combine(v1, v2), v3) + ) } } @@ -45,7 +48,7 @@ class VarianceTest extends AnyFunSuite { test("variance is distributive") { forAll { (v1: Variance, v2: Variance, v3: Variance) => - val left = v1 * (v2 + v3) + val left = v1 * (v2 + v3) val right = (v1 * v2) + (v1 * v3) assert(left == right, s"$left != $right") } @@ -56,7 +59,7 @@ class VarianceTest extends AnyFunSuite { val v2 = Variance.phantom val v3 = Variance.co - val left = v1 * (v2 + v3) + val left = v1 * (v2 + v3) val right = (v1 * v2) + (v1 * v3) assert(left == right, s"$left != $right") } @@ -112,7 +115,12 @@ class VarianceTest extends AnyFunSuite { } test("covariant combines to get either covariant or invariant") { - assert(V.combine(Variance.Covariant, Variance.Contravariant) == Variance.Invariant) + assert( + V.combine( + Variance.Covariant, + Variance.Contravariant + ) == Variance.Invariant + ) val results = Set(Variance.co, Variance.in) forAll { (v1: Variance) => assert(results(V.combine(v1, Variance.Covariant))) @@ -120,7 +128,12 @@ class VarianceTest extends AnyFunSuite { } test("contravariant combines to get either contravariant or invariant") { - assert(V.combine(Variance.Covariant, Variance.Contravariant) == Variance.Invariant) + assert( + V.combine( + Variance.Covariant, + Variance.Contravariant + ) == Variance.Invariant + ) val results = Set(Variance.contra, Variance.in) forAll { (v1: Variance) => assert(results(V.combine(v1, Variance.Contravariant))) diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala index 7ef56c142..a00b62edd 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/python/PythonGenTest.scala @@ -2,14 +2,17 @@ package org.bykn.bosatsu.codegen.python import org.bykn.bosatsu.Identifier.{Bindable, unsafeBindable} import org.bykn.bosatsu.Generators.bindIdentGen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class PythonGenTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 50000) PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 500) + // PropertyCheckConfiguration(minSuccessful = 500) test("PythonGen.escape round trips") { @@ -17,16 +20,14 @@ class PythonGenTest extends AnyFunSuite { val ident = PythonGen.escape(b) PythonGen.unescape(ident) match { case Some(b1) => assert(b1.asString == b.asString) - case None => assert(false, s"$b => $ident could not round trip") + case None => assert(false, s"$b => $ident could not round trip") } } forAll(bindIdentGen)(law(_)) val examples: List[Bindable] = - List( - "`12 =_=`", - "`N`").map(unsafeBindable) + List("`12 =_=`", "`N`").map(unsafeBindable) examples.foreach(law(_)) @@ -38,7 +39,10 @@ class PythonGenTest extends AnyFunSuite { forAll(bindIdentGen) { b => val str = PythonGen.escape(b).name - assert(PythonName.matcher(str).matches(), s"escaped: ${b.sourceCodeRepr} to $str") + assert( + PythonName.matcher(str).matches(), + s"escaped: ${b.sourceCodeRepr} to $str" + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala b/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala index b23f6bdf6..0bbcd96b9 100644 --- a/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/graph/ToposortTest.scala @@ -3,14 +3,17 @@ package org.bykn.bosatsu.graph import cats.Order import cats.data.NonEmptyList import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import cats.implicits._ import org.scalatest.funsuite.AnyFunSuite class ToposortTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 1000) test("toposort can recover full sort") { @@ -25,7 +28,11 @@ class ToposortTest extends AnyFunSuite { assert(res.isSuccess) assert(res.isFailure == res.loopNodes.nonEmpty) assert(res.toSuccess == Some(res.layers)) - assert(res.layers == items.toVector.sorted(Order[A].toOrdering).map(NonEmptyList(_, Nil))) + assert( + res.layers == items.toVector + .sorted(Order[A].toOrdering) + .map(NonEmptyList(_, Nil)) + ) assert(res.layersAreTotalOrder) } @@ -43,7 +50,10 @@ class ToposortTest extends AnyFunSuite { val nset = fn(n).toSet if (nset.nonEmpty) { (id until layers.size).foreach { id1 => - assert(layers(id1).filter(nset).isEmpty, s"node $n in layer $id has points to later layers: $id1") + assert( + layers(id1).filter(nset).isEmpty, + s"node $n in layer $id has points to later layers: $id1" + ) } } } @@ -59,13 +69,16 @@ class ToposortTest extends AnyFunSuite { val nid = Gen.choose(0, 100) val pair = for { n <- nid - neighbor <- Gen.listOf(nid).map(_.filter(_ < n).distinct) // make sure it is a dag + neighbor <- Gen + .listOf(nid) + .map(_.filter(_ < n).distinct) // make sure it is a dag } yield (n, neighbor) val genDag = Gen.mapOf(pair).map(Dag(_)) forAll(genDag) { case Dag(graph) => val allNodes = graph.flatMap { case (h, t) => h :: t }.toSet - val Toposort.Success(sorted, _) = Toposort.sort(allNodes)(graph.getOrElse(_, Nil)) + val Toposort.Success(sorted, _) = + Toposort.sort(allNodes)(graph.getOrElse(_, Nil)) assert(sorted.flatMap(_.toList).sorted == allNodes.toList.sorted) noEdgesToLater(sorted)(n => graph.getOrElse(n, Nil)) layersAreSorted(sorted) @@ -87,7 +100,9 @@ class ToposortTest extends AnyFunSuite { layersAreSorted(layers) // all the nodes is the same set: val goodNodes = layers.flatMap(_.toList) - assert((goodNodes.toList ::: res.loopNodes).sorted == allNodes.toList.sorted) + assert( + (goodNodes.toList ::: res.loopNodes).sorted == allNodes.toList.sorted + ) // good nodes are distinct assert(goodNodes == goodNodes.distinct) // loop nodes are distinct @@ -103,7 +118,16 @@ class ToposortTest extends AnyFunSuite { } test("we return the least node with a loop") { - assert(Toposort.sort(List(1, 2))(Function.const(List(1, 2))) == Toposort.Failure(List(1, 2), Vector.empty)) - assert(Toposort.sort(List("bb", "aa"))(Function.const(List("aa", "bb"))) == Toposort.Failure(List("aa", "bb"), Vector.empty)) + assert( + Toposort.sort(List(1, 2))(Function.const(List(1, 2))) == Toposort.Failure( + List(1, 2), + Vector.empty + ) + ) + assert( + Toposort.sort(List("bb", "aa"))( + Function.const(List("aa", "bb")) + ) == Toposort.Failure(List("aa", "bb"), Vector.empty) + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala b/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala index f077064cc..2b80c71de 100644 --- a/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/graph/TreeTest.scala @@ -10,26 +10,27 @@ import org.scalatest.funsuite.AnyFunSuite class TreeTest extends AnyFunSuite { test("explicit dags never fail") { - val dagFn: Gen[Int => List[Int]] = + val dagFn: Gen[Int => List[Int]] = Gen.choose(1L, Long.MaxValue).map { seed => - val rng = new java.util.Random(seed) val cache = scala.collection.mutable.Map[Int, List[Int]]() { (node: Int) => // the expected number of neighbors is 1.5, that means the graph is expected to be finite - cache.getOrElseUpdate(node, { - val count = rng.nextInt(3) - (node + 1 until (node + count + 1)).toList.filter(_ > node) - }) + cache.getOrElseUpdate( + node, { + val count = rng.nextInt(3) + (node + 1 until (node + count + 1)).toList.filter(_ > node) + } + ) } } forAll(Gen.choose(0, Int.MaxValue), dagFn) { (start, nfn) => Tree.dagToTree(start)(nfn) match { - case v@Validated.Valid(tree) => + case v @ Validated.Valid(tree) => // the neightbor function should give the same tree: val treeFn = Tree.neighborsFn(tree) val tree2 = Tree.dagToTree(tree.item)(treeFn) @@ -42,9 +43,10 @@ class TreeTest extends AnyFunSuite { } test("circular graphs are invalid") { - val prime = Gen.oneOf(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89) + val prime = Gen.oneOf(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, + 47, 53, 59, 61, 67, 71, 73, 79, 83, 89) - val dagFn: Gen[(Int, Int => List[Int])] = + val dagFn: Gen[(Int, Int => List[Int])] = for { p <- prime p1 = p - 1 @@ -54,10 +56,13 @@ class TreeTest extends AnyFunSuite { b <- nodeGen } yield { - (init, { (node: Int) => - // only 1 neighbor, but this is in a cyclic group so it can't be a dag - List((node * a + b) % p) - }) + ( + init, + { (node: Int) => + // only 1 neighbor, but this is in a cyclic group so it can't be a dag + List((node * a + b) % p) + } + ) } forAll(dagFn) { case (start, nfn) => @@ -97,7 +102,8 @@ class TreeTest extends AnyFunSuite { NonEmptyList.fromList(l1.filterNot(nel0.toList.toSet)) match { case None => succeed case Some(diffs) => - val got = Tree.distinctBy(nel0)(identity) ::: Tree.distinctBy(diffs)(identity) + val got = + Tree.distinctBy(nel0)(identity) ::: Tree.distinctBy(diffs)(identity) val expected = Tree.distinctBy(nel0 ::: diffs)(identity) assert(got == expected) } diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala index fb78105ed..5736af319 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/SeqPatternTest.scala @@ -1,7 +1,10 @@ package org.bykn.bosatsu.pattern import org.scalacheck.{Arbitrary, Gen, Shrink} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import SeqPattern.{Cat, Empty} import SeqPart.{Wildcard, AnyElem, Lit} @@ -28,7 +31,8 @@ object StringSeqPatternGen { Gen.frequency( (15, Gen.oneOf(Lit('0'), Lit('1'))), (2, Gen.const(AnyElem)), - (1, Gen.const(Wildcard))) + (1, Gen.const(Wildcard)) + ) } val genPat: Gen[SeqPattern[Char]] = { @@ -38,9 +42,7 @@ object StringSeqPatternGen { t <- genPat } yield Cat(h, t) - Gen.frequency( - (1, Gen.const(Empty)), - (5, cat)) + Gen.frequency((1, Gen.const(Empty)), (5, cat)) } implicit val arbPattern: Arbitrary[SeqPattern[Char]] = Arbitrary(genPat) @@ -54,7 +56,10 @@ object StringSeqPatternGen { t #:: shrinkPat.shrink(t) } - def genNamedFn[A](gp: Gen[SeqPart[A]], nextId: Int): Gen[(Int, NamedSeqPattern[A])] = { + def genNamedFn[A]( + gp: Gen[SeqPart[A]], + nextId: Int + ): Gen[(Int, NamedSeqPattern[A])] = { lazy val recur = Gen.lzy(res) lazy val genNm: Gen[(Int, Named.Bind[A])] = @@ -68,13 +73,17 @@ object StringSeqPatternGen { // L = (4/9) / (1 - (2/3 + 1/9)) = 4 / (9 - 5) = 1 lazy val res: Gen[(Int, NamedSeqPattern[A])] = Gen.frequency( - (3, for { - (i0, n0) <- recur - (i1, n1) <- genNamedFn(gp, i0) - } yield (i1, NamedSeqPattern.NCat(n0, n1))), + ( + 3, + for { + (i0, n0) <- recur + (i1, n1) <- genNamedFn(gp, i0) + } yield (i1, NamedSeqPattern.NCat(n0, n1)) + ), (1, genNm), (1, Gen.const((nextId, NamedSeqPattern.NEmpty))), - (4, gp.map { p => (nextId, NamedSeqPattern.NSeqPart(p)) })) + (4, gp.map { p => (nextId, NamedSeqPattern.NSeqPart(p)) }) + ) res } @@ -85,7 +94,9 @@ object StringSeqPatternGen { implicit val arbNamed: Arbitrary[NamedSeqPattern[Char]] = Arbitrary(genNamed) def interleave[A](s1: Stream[A], s2: Stream[A]): Stream[A] = - if (s1.isEmpty) s2 else if (s2.isEmpty) s1 else { + if (s1.isEmpty) s2 + else if (s2.isEmpty) s1 + else { s1.head #:: interleave(s2, s1.tail) } @@ -102,32 +113,34 @@ object StringSeqPatternGen { val sp = p match { case Wildcard => AnyElem #:: tail - case AnyElem => tail - case Lit(_) => Stream.Empty + case AnyElem => tail + case Lit(_) => Stream.Empty } sp.map(NamedSeqPattern.NSeqPart(_)) case NamedSeqPattern.NCat(fst, snd) => val s1 = shrinkNamedSeqPattern.shrink(fst) val s2 = shrinkNamedSeqPattern.shrink(snd) - interleave(s1, s2).iterator.sliding(2).map { - case Seq(a, b) => NamedSeqPattern.NCat(a, b) - case _ => NamedSeqPattern.NEmpty - } - .toStream + interleave(s1, s2).iterator + .sliding(2) + .map { + case Seq(a, b) => NamedSeqPattern.NCat(a, b) + case _ => NamedSeqPattern.NEmpty + } + .toStream } def unany[A](p: SeqPattern[A]): SeqPattern[A] = p match { - case Empty => Empty + case Empty => Empty case Cat(AnyElem, t) => unany(t) - case Cat(h, t) => Cat(h, unany(t)) + case Cat(h, t) => Cat(h, unany(t)) } def unwild[A](p: SeqPattern[A]): SeqPattern[A] = p match { - case Empty => Empty + case Empty => Empty case Cat(Wildcard, t) => unwild(t) - case Cat(h, t) => Cat(h, unwild(t)) + case Cat(h, t) => Cat(h, unwild(t)) } implicit val arbString: Arbitrary[String] = Arbitrary(genBitString) @@ -149,9 +162,9 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { val Named = NamedSeqPattern implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 50000) PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 50) + // PropertyCheckConfiguration(minSuccessful = 50) def genPattern: Gen[Pattern] def genNamed: Gen[Named] @@ -174,8 +187,8 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { val together = intersect.exists(matches(_, x)) assert(together == sep, s"n1: $n1, n2: $n2, intersection: $intersect") - //if (together != sep) sys.error(s"n1: $n1, n2: $n2, intersection: $intersect") - //else succeed + // if (together != sep) sys.error(s"n1: $n1, n2: $n2, intersection: $intersect") + // else succeed } def namedMatchesPatternLaw(n: Named, str: S) = { @@ -208,7 +221,7 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { if (p2.matchesAny) assert(diff == Nil) // the law we wish we had: - //if (p2.matches(s) && p1.matches(s)) assert(!diffmatch) + // if (p2.matches(s) && p1.matches(s)) assert(!diffmatch) } test("reverse is idempotent") { @@ -220,8 +233,8 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { test("cat.reverse == cat.reverseCat") { forAll(genPattern) { p => p match { - case Empty => assert(p.reverse == Empty) - case c@Cat(_, _) => assert(c.reverseCat == c.reverse) + case Empty => assert(p.reverse == Empty) + case c @ Cat(_, _) => assert(c.reverseCat == c.reverse) } } } @@ -229,16 +242,19 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { test("reverse matches the reverse string") { forAll(genPattern, genSeq) { (p: Pattern, str: S) => val rstr = splitter.fromList(splitter.toList(str).reverse) - assert(matches(p, str) == matches(p.reverse, rstr), s"p.reverse = ${p.reverse}") + assert( + matches(p, str) == matches(p.reverse, rstr), + s"p.reverse = ${p.reverse}" + ) } } test("unlit patterns match everything") { def unlit(p: Pattern): Pattern = p match { - case Empty => Empty + case Empty => Empty case Cat(Lit(_) | AnyElem, t) => unlit(t) - case Cat(h, t) => Cat(h, unlit(t)) + case Cat(h, t) => Cat(h, unlit(t)) } forAll(genPattern, genSeq) { (p0, str) => @@ -258,7 +274,7 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { forAll(genPattern) { p0 => val list = p0.normalize.toList list.sliding(2).foreach { - case bad@Seq(Wildcard, Wildcard) => + case bad @ Seq(Wildcard, Wildcard) => fail(s"saw adjacent: $bad in ${p0.normalize}") case _ => () } @@ -318,13 +334,17 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { } test("if subset(a, b) then matching a implies matching b") { - forAll(genPattern, genPattern, genSeq)(subsetConsistentWithMatchLaw(_, _, _)) + forAll(genPattern, genPattern, genSeq)( + subsetConsistentWithMatchLaw(_, _, _) + ) } def diffUBRegressions: List[(Pattern, Pattern, S)] = Nil test("difference is an upper bound") { - forAll(genPattern, genPattern, genSeq) { case (p1, p2, s) => differenceUBLaw(p1, p2, s) } + forAll(genPattern, genPattern, genSeq) { case (p1, p2, s) => + differenceUBLaw(p1, p2, s) + } diffUBRegressions.foreach { case (p1, p2, s) => differenceUBLaw(p1, p2, s) @@ -332,10 +352,11 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { } test("p + q match (s + t) if p.matches(s) && q.matches(t)") { - forAll(genPattern, genPattern, genSeq, genSeq) { (p: Pattern, q: Pattern, s: S, t: S) => - if (matches(p, s) && matches(q, t)) { - assert(matches(p + q, splitter.catSeqs(s :: t :: Nil))) - } + forAll(genPattern, genPattern, genSeq, genSeq) { + (p: Pattern, q: Pattern, s: S, t: S) => + if (matches(p, s) && matches(q, t)) { + assert(matches(p + q, splitter.catSeqs(s :: t :: Nil))) + } } } @@ -343,7 +364,7 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { forAll(genNamed, genSeq)(namedMatchesPatternLaw(_, _)) } -/* + /* test("if x - y is empty, (x + z) - (y + z) is empty") { forAll { (x0: Pattern, y0: Pattern, z0: Pattern) => val x = Pattern.fromList(x0.toList.take(3)) @@ -354,10 +375,11 @@ abstract class SeqPatternLaws[E, I, S, R] extends AnyFunSuite { } } } -*/ + */ } -class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Boolean], Unit] { +class BoolSeqPatternTest + extends SeqPatternLaws[Set[Boolean], Boolean, List[Boolean], Unit] { implicit lazy val shrinkPat: Shrink[SeqPattern[Set[Boolean]]] = Shrink { @@ -365,7 +387,10 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool case Cat(Wildcard, t) => (Cat(AnyElem, t) #:: t #:: Stream.empty).flatMap(shrinkPat.shrink) case Cat(AnyElem, t) => - (Cat(Lit(Set(false)), t) #:: Cat(Lit(Set(true)), t) #:: t #:: Stream.empty).flatMap(shrinkPat.shrink) + (Cat(Lit(Set(false)), t) #:: Cat( + Lit(Set(true)), + t + ) #:: t #:: Stream.empty).flatMap(shrinkPat.shrink) case Cat(_, t) => t #:: shrinkPat.shrink(t) } @@ -383,7 +408,8 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool Gen.frequency( (15, genSetBool.map(Lit(_))), (2, Gen.const(AnyElem)), - (1, Gen.const(Wildcard))) + (1, Gen.const(Wildcard)) + ) } val genNamed: Gen[NamedSeqPattern[Set[Boolean]]] = @@ -393,29 +419,42 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool val sp = Gen.frequency( (1, SeqPart.Wildcard), (2, SeqPart.AnyElem), - (8, genSetBool.map(SeqPart.Lit(_)))) + (8, genSetBool.map(SeqPart.Lit(_))) + ) Gen.frequency( (1, Gen.const(SeqPattern.Empty)), - (5, Gen.zip(sp, Gen.lzy(genPattern)).map { case (h, t) => SeqPattern.Cat(h, t) }) + ( + 5, + Gen.zip(sp, Gen.lzy(genPattern)).map { case (h, t) => + SeqPattern.Cat(h, t) + } + ) ) } def genSeq: Gen[List[Boolean]] = Gen.choose(0, 20).flatMap(Gen.listOfN(_, genBool)) - val splitter = Splitter.listSplitter(Matcher.fnMatch[Boolean]: Matcher[Set[Boolean], Boolean, Unit]) + val splitter = Splitter.listSplitter( + Matcher.fnMatch[Boolean]: Matcher[Set[Boolean], Boolean, Unit] + ) val pmatcher = SeqPattern.matcher(splitter) - def matches(p: SeqPattern[Set[Boolean]], s: List[Boolean]): Boolean = pmatcher(p)(s).isDefined - def namedMatches(p: NamedSeqPattern[Set[Boolean]], s: List[Boolean]): Boolean = + def matches(p: SeqPattern[Set[Boolean]], s: List[Boolean]): Boolean = + pmatcher(p)(s).isDefined + def namedMatches( + p: NamedSeqPattern[Set[Boolean]], + s: List[Boolean] + ): Boolean = NamedSeqPattern.matcher(splitter)(p)(s).isDefined - implicit val setOpsBool: SetOps[Set[Boolean]] = SetOps.fromFinite(List(true, false)) + implicit val setOpsBool: SetOps[Set[Boolean]] = + SetOps.fromFinite(List(true, false)) implicit val ordSet: Ordering[Set[Boolean]] = Ordering[List[Boolean]].on { (s: Set[Boolean]) => s.toList.sorted } - val setOps: SetOps[SeqPattern[Set[Boolean]]] = SeqPattern.seqPatternSetOps[Set[Boolean]] - + val setOps: SetOps[SeqPattern[Set[Boolean]]] = + SeqPattern.seqPatternSetOps[Set[Boolean]] // we can sometimes enumerate a finite LazyList of matches def enumerate(sp: SeqPattern[Set[Boolean]]): Option[LazyList[List[Boolean]]] = @@ -427,7 +466,7 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool val rests = loop(rest) val heads = hsp match { case Lit(s) if s.size == 1 => s.head :: Nil - case _ => + case _ => // we assume any because there // are no wilds List(false, true) @@ -444,8 +483,14 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool override def diffUBRegressions = List({ - val p1 = Cat(Wildcard,Empty) - val p2 = Cat(Lit(Set(true)),Cat(Wildcard,Cat(Lit(Set(true, false)),Cat(Lit(Set(true)),Cat(Wildcard, Empty))))) + val p1 = Cat(Wildcard, Empty) + val p2 = Cat( + Lit(Set(true)), + Cat( + Wildcard, + Cat(Lit(Set(true, false)), Cat(Lit(Set(true)), Cat(Wildcard, Empty))) + ) + ) val s = List(true, false, false) (p1, p2, s) @@ -458,7 +503,8 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool SeqPattern.fromList(Wildcard :: Lit(Set(true)) :: Wildcard :: Nil), SeqPattern.fromList(Lit(Set(false)) :: Wildcard :: Nil), SeqPattern.fromList(Nil) - )) + ) + ) assert(missing == Nil) } @@ -477,8 +523,7 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool assert(matches(p1, s), s"p1: $s") assert(matches(p2, s), s"p2: $s") } - } - else { + } else { // difference is an upper-bound // so truediff <= diff // if ms does not match diff, then it must @@ -539,26 +584,82 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool val subsets: List[(Pattern, Pattern, Boolean)] = List( { - val p0 = Cat(Wildcard,Cat(Lit(Set(false)),Cat(Lit(Set(true, false)),Cat(Lit(Set(true, false)),Empty)))) - val p1 = Cat(Wildcard,Cat(Lit(Set(true, false)),Empty)) + val p0 = Cat( + Wildcard, + Cat( + Lit(Set(false)), + Cat(Lit(Set(true, false)), Cat(Lit(Set(true, false)), Empty)) + ) + ) + val p1 = Cat(Wildcard, Cat(Lit(Set(true, false)), Empty)) (p0, p1, true) } ) - subsets.foreach { case (p1, p2, res) => assert(setOps.subset(p1, p2) == res) } + subsets.foreach { case (p1, p2, res) => + assert(setOps.subset(p1, p2) == res) + } val regressions: List[(Pattern, Pattern, List[List[Boolean]])] = List( { - val p0 = Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(true)),Cat(AnyElem,Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(false)),Cat(Lit(Set(true, false)),Cat(Lit(Set(true)),Empty)))))))))) - val p1 = Cat(Wildcard,Cat(Lit(Set(true, false)),Cat(Lit(Set(true)),Cat(Lit(Set(false)),Cat(Lit(Set(true, false)),Cat(Wildcard,Cat(Lit(Set(true, false)),Cat(Lit(Set(true, false)),Empty)))))))) + val p0 = Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(true)), + Cat( + AnyElem, + Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(false)), + Cat(Lit(Set(true, false)), Cat(Lit(Set(true)), Empty)) + ) + ) + ) + ) + ) + ) + ) + ) + val p1 = Cat( + Wildcard, + Cat( + Lit(Set(true, false)), + Cat( + Lit(Set(true)), + Cat( + Lit(Set(false)), + Cat( + Lit(Set(true, false)), + Cat( + Wildcard, + Cat( + Lit(Set(true, false)), + Cat(Lit(Set(true, false)), Empty) + ) + ) + ) + ) + ) + ) + ) val matchp0 = Nil (p0, p1, matchp0) - }, - { - val p0 = Cat(Lit(Set(true)),Cat(AnyElem,Cat(Lit(Set(false)),Empty))) - val p1 = Cat(Wildcard,Cat(Lit(Set(true)),Cat(Lit(Set(false)),Cat(Wildcard,Empty)))) + }, { + val p0 = + Cat(Lit(Set(true)), Cat(AnyElem, Cat(Lit(Set(false)), Empty))) + val p1 = Cat( + Wildcard, + Cat(Lit(Set(true)), Cat(Lit(Set(false)), Cat(Wildcard, Empty))) + ) val matchp0 = Nil (p0, p1, matchp0) @@ -570,11 +671,23 @@ class BoolSeqPatternTest extends SeqPatternLaws[Set[Boolean], Boolean, List[Bool } test("test some missing branches") { - assert(setOps.missingBranches(Cat(Wildcard, Empty) :: Nil, Pattern.fromList(List(Wildcard, AnyElem, Wildcard)) :: Nil) == - Pattern.fromList(Nil) :: Nil) + assert( + setOps.missingBranches( + Cat(Wildcard, Empty) :: Nil, + Pattern.fromList(List(Wildcard, AnyElem, Wildcard)) :: Nil + ) == + Pattern.fromList(Nil) :: Nil + ) - assert(setOps.missingBranches(Cat(Wildcard, Empty) :: Nil, Pattern.fromList(List(Wildcard, Lit(Set(true)), Wildcard)) :: Nil) == - Pattern.fromList(Nil) :: Pattern.fromList(Lit(Set(false)) :: Wildcard :: Nil) :: Nil) + assert( + setOps.missingBranches( + Cat(Wildcard, Empty) :: Nil, + Pattern.fromList(List(Wildcard, Lit(Set(true)), Wildcard)) :: Nil + ) == + Pattern.fromList(Nil) :: Pattern.fromList( + Lit(Set(false)) :: Wildcard :: Nil + ) :: Nil + ) } } @@ -609,23 +722,30 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { override def diffUBRegressions = List({ - val p1 = Cat(AnyElem,Cat(Wildcard,Empty)) - val p2 = Cat(Wildcard,Cat(AnyElem,Cat(Lit('0'),Cat(Lit('1'),Cat(Wildcard, Empty))))) + val p1 = Cat(AnyElem, Cat(Wildcard, Empty)) + val p2 = Cat( + Wildcard, + Cat(AnyElem, Cat(Lit('0'), Cat(Lit('1'), Cat(Wildcard, Empty)))) + ) (p1, p2, "11") }) test("some matching examples") { - val ms: List[(Pattern, String)] = + val ms: List[(Pattern, String)] = (Pattern.Wild + Pattern.Any + Pattern.Any + toPattern("1"), "111") :: - (toPattern("1") + Pattern.Any + toPattern("1"), "111") :: - Nil + (toPattern("1") + Pattern.Any + toPattern("1"), "111") :: + Nil ms.foreach { case (p, s) => assert(matches(p, s), s"matches($p, $s)") } } test("wildcard on either side is the same as contains") { forAll { (ps: String, s: String) => - assert(matches(Pattern.Wild + toPattern(ps) + Pattern.Wild, s) == s.contains(ps)) + assert( + matches(Pattern.Wild + toPattern(ps) + Pattern.Wild, s) == s.contains( + ps + ) + ) } } test("wildcard on front side is the same as endsWith") { @@ -639,18 +759,38 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { } } - test("intersection(p1, p2).matches(x) == p1.matches(x) && p2.matches(x) regressions") { + test( + "intersection(p1, p2).matches(x) == p1.matches(x) && p2.matches(x) regressions" + ) { val regressions: List[(Pattern, Pattern, String)] = - (toPattern("0") + Pattern.Any + Pattern.Wild, Pattern.Any + toPattern("01") + Pattern.Any, "001") :: - (Pattern.Wild + Pattern.Any + Pattern.Any + toPattern("1"), toPattern("1") + Pattern.Any + toPattern("1"), "111") :: - Nil + ( + toPattern("0") + Pattern.Any + Pattern.Wild, + Pattern.Any + toPattern("01") + Pattern.Any, + "001" + ) :: + ( + Pattern.Wild + Pattern.Any + Pattern.Any + toPattern("1"), + toPattern("1") + Pattern.Any + toPattern("1"), + "111" + ) :: + Nil regressions.foreach { case (p1, p2, s) => intersectionMatchLaw(p1, p2, s) } } test("subset is consistent with match regressions") { - assert(setOps.subset(toPattern("00") + Pattern.Wild, toPattern("0") + Pattern.Wild)) - assert(setOps.subset(toPattern("00") + Pattern.Any + Pattern.Wild, toPattern("0") + Pattern.Any + Pattern.Wild)) + assert( + setOps.subset( + toPattern("00") + Pattern.Wild, + toPattern("0") + Pattern.Wild + ) + ) + assert( + setOps.subset( + toPattern("00") + Pattern.Any + Pattern.Wild, + toPattern("0") + Pattern.Any + Pattern.Wild + ) + ) } test("if y - x is empty, (yz - xz) for all strings is empty") { @@ -667,13 +807,14 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { } test("if y - x is empty, (zy - zx) for all strings is empty") { - forAll(genPattern, genPattern, genSeq) { (x0: Pattern, y0: Pattern, str: String) => - val x = Pattern.fromList(x0.toList.take(5)) - val y = Pattern.fromList(y0.toList.take(5)) - if (setOps.difference(y, x) == Nil) { - val left = toPattern(str) + y - assert(setOps.difference(left, toPattern(str) + x) == Nil) - } + forAll(genPattern, genPattern, genSeq) { + (x0: Pattern, y0: Pattern, str: String) => + val x = Pattern.fromList(x0.toList.take(5)) + val y = Pattern.fromList(y0.toList.take(5)) + if (setOps.difference(y, x) == Nil) { + val left = toPattern(str) + y + assert(setOps.difference(left, toPattern(str) + x) == Nil) + } } } @@ -682,7 +823,7 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { namedMatch(n, str).foreach { m => n.render(m)(_.toString) match { case Some(s0) => assert(s0 == str, s"m = $m") - case None => + case None => // this can only happen if we have unnamed Wild/AnyElem assert(n.isRenderable == false, s"m = $m") } @@ -695,8 +836,11 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { import SeqPart.Wildcard val regressions: List[(Named, String)] = - (NCat(Bind("0",NSeqPart(Wildcard)),Bind("1",NSeqPart(Wildcard))), "") :: - Nil + ( + NCat(Bind("0", NSeqPart(Wildcard)), Bind("1", NSeqPart(Wildcard))), + "" + ) :: + Nil regressions.foreach { case (n, s) => law(n, s) } } @@ -708,17 +852,22 @@ class SeqPatternTest extends SeqPatternLaws[Char, Char, String, Unit] { import Named._ val regressions: List[(Named, String)] = - (NCat(NEmpty,NCat(NSeqPart(Lit('1')),NSeqPart(Wildcard))), "1") :: - (NCat(NSeqPart(Lit('1')),NSeqPart(Wildcard)), "1") :: - (NSeqPart(Lit('1')), "1") :: - Nil + (NCat(NEmpty, NCat(NSeqPart(Lit('1')), NSeqPart(Wildcard))), "1") :: + (NCat(NSeqPart(Lit('1')), NSeqPart(Wildcard)), "1") :: + (NSeqPart(Lit('1')), "1") :: + Nil regressions.foreach { case (n, s) => namedMatchesPatternLaw(n, s) } } test("Test some examples of Named matching") { // foo@("bar" baz@(*"baz")) - val p1 = (named("bar") + (Named.Wild + named("baz")).name("baz")).name("foo") - assert(namedMatch(p1, "bar and baz") == Some(Map("foo" -> "bar and baz", "baz" -> " and baz"))) + val p1 = + (named("bar") + (Named.Wild + named("baz")).name("baz")).name("foo") + assert( + namedMatch(p1, "bar and baz") == Some( + Map("foo" -> "bar and baz", "baz" -> " and baz") + ) + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/SetOpsLaws.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/SetOpsLaws.scala index 4440f0d44..555b37e00 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/SetOpsLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/SetOpsLaws.scala @@ -2,7 +2,10 @@ package org.bykn.bosatsu.pattern import cats.Eq import org.scalacheck.{Arbitrary, Cogen, Gen} -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite abstract class SetOpsLaws[A] extends AnyFunSuite { @@ -39,7 +42,6 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { } } - test("intersection is commutative") { forAll(genItem, genItem, eqUnion)(intersectionIsCommutative(_, _, _)) } @@ -60,7 +62,9 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { } test("intersection is associative") { - forAll(genItem, genItem, genItem, eqUnion)(intersectionIsAssociative(_, _, _, _)) + forAll(genItem, genItem, genItem, eqUnion)( + intersectionIsAssociative(_, _, _, _) + ) } test("unify union makes size <= input") { @@ -106,14 +110,14 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { } test("if a n b = 0 then a - b = a") { - // difference is an upper bound, so this is not true - // although we wish it were - /* + // difference is an upper bound, so this is not true + // although we wish it were + /* if (diff.map(_.normalize).distinct == p1.normalize :: Nil) { // intersection is 0 assert(inter == Nil) } - */ + */ forAll(genItem, genItem, eqUnion)(emptyIntersectionMeansDiffIdent(_, _, _)) } @@ -201,8 +205,7 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { // should be in that case, if (a - b) = a, then // clearly we expect (a n c) == (a n c) - (b n c) // so, b n c has to not intersect with a, but it might - } - else if (isTop(a) && intBC.isEmpty) { + } else if (isTop(a) && intBC.isEmpty) { // in patterns, we "cast" ill-typed comparisions // since we can don't care about cases that don't // type-check. But this can make this law fail: @@ -212,8 +215,7 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { // but (_ n c) = c, and b n c = 0 val leftEqC = differenceAll(unifyUnion(left), c :: Nil).isEmpty assert((left == Nil) || leftEqC) - } - else { + } else { val intAC = intersection(a, c) val right = differenceAll(intAC, intBC) @@ -222,10 +224,12 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { val leftu = unifyUnion(left) if (leftu == unifyUnion(intAC)) { succeed - } - else { + } else { val rightu = unifyUnion(right) - assert(leftu == rightu, s"diffAB = $diffab, intAC = $intAC, intBC = $intBC") + assert( + leftu == rightu, + s"diffAB = $diffab, intAC = $intAC, intBC = $intBC" + ) } } } @@ -237,9 +241,11 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { test("(a - b) n c = (a n c) - (b n c)") { forAll(genItem, genItem, genItem)(diffIntersectionLaw(_, _, _)) } - */ + */ - test("missing branches, if added are total and none of the missing are unreachable") { + test( + "missing branches, if added are total and none of the missing are unreachable" + ) { def law(top: A, pats: List[A]) = { @@ -247,9 +253,11 @@ abstract class SetOpsLaws[A] extends AnyFunSuite { val rest1 = missingBranches(top :: Nil, pats ::: rest) if (rest1.isEmpty) { val unreach = unreachableBranches(pats ::: rest) - assert(unreach.filter(rest.toSet) == Nil, s"\n\nrest = ${rest}\n\ninit: ${pats}") - } - else { + assert( + unreach.filter(rest.toSet) == Nil, + s"\n\nrest = ${rest}\n\ninit: ${pats}" + ) + } else { fail(s"after adding ${rest} we still need ${rest1}") } } @@ -283,9 +291,11 @@ class DistinctSetOpsTest extends SetOpsLaws[Byte] { class IMapSetOpsTest extends SetOpsLaws[Byte] { val setOps: SetOps[Byte] = - SetOps.imap(SetOps.distinct[Byte], - { (b: Byte) => (b ^ 0xFF).toByte }, - { (b: Byte) => (b ^ 0xFF).toByte }) + SetOps.imap( + SetOps.distinct[Byte], + { (b: Byte) => (b ^ 0xff).toByte }, + { (b: Byte) => (b ^ 0xff).toByte } + ) val genItem: Gen[Byte] = Gen.choose(Byte.MinValue, Byte.MaxValue) @@ -296,7 +306,8 @@ class IMapSetOpsTest extends SetOpsLaws[Byte] { } class ProductSetOpsTest extends SetOpsLaws[(Boolean, Boolean)] { - val setOps: SetOps[(Boolean, Boolean)] = SetOps.product(SetOps.distinct[Boolean], SetOps.distinct[Boolean]) + val setOps: SetOps[(Boolean, Boolean)] = + SetOps.product(SetOps.distinct[Boolean], SetOps.distinct[Boolean]) val genItem: Gen[(Boolean, Boolean)] = Gen.oneOf((false, false), (false, true), (true, false), (true, true)) @@ -319,7 +330,6 @@ class UnitSetOpsTest extends SetOpsLaws[Unit] { }) } - case class Predicate[A](toFn: A => Boolean) { self => def apply(a: A): Boolean = toFn(a) def &&(that: Predicate[A]): Predicate[A] = @@ -343,44 +353,49 @@ object Predicate { Arbitrary(genPred[A]) } - class SetOpsTests extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) test("allPerms is correct") { - forAll(Gen.choose(0, 6).flatMap(Gen.listOfN(_, Arbitrary.arbitrary[Int]))) { is0 => - // make everything distinct - val is = is0.zipWithIndex - val perms = SetOps.allPerms(is) + forAll(Gen.choose(0, 6).flatMap(Gen.listOfN(_, Arbitrary.arbitrary[Int]))) { + is0 => + // make everything distinct + val is = is0.zipWithIndex + val perms = SetOps.allPerms(is) - def fact(i: Int, acc: Int): Int = - if (i <= 1) acc - else fact(i - 1, i * acc) + def fact(i: Int, acc: Int): Int = + if (i <= 1) acc + else fact(i - 1, i * acc) - assert(perms.length == fact(is0.size, 1)) + assert(perms.length == fact(is0.size, 1)) - perms.foreach { p => - assert(p.sorted == is.sorted) - } - val pi = perms.zipWithIndex + perms.foreach { p => + assert(p.sorted == is.sorted) + } + val pi = perms.zipWithIndex - for { - (p1, i1) <- pi - (p2, i2) <- pi - } assert((i1 >= i2 || (p1 != p2))) + for { + (p1, i1) <- pi + (p2, i2) <- pi + } assert((i1 >= i2 || (p1 != p2))) } } - test("greedySearch finds the optimal path if lookahead is greater than size") { + test( + "greedySearch finds the optimal path if lookahead is greater than size" + ) { // we need a non-commutative operation to test this // use 2x2 matrix multiplication - def mult(left: Vector[Vector[Double]], right: Vector[Vector[Double]]): Vector[Vector[Double]] = { + def mult( + left: Vector[Vector[Double]], + right: Vector[Vector[Double]] + ): Vector[Vector[Double]] = { def dot(v1: Vector[Double], v2: Vector[Double]) = - v1.iterator.zip(v2.iterator).map { case (a, b) => a*b }.sum + v1.iterator.zip(v2.iterator).map { case (a, b) => a * b }.sum def trans(v1: Vector[Vector[Double]]) = Vector(Vector(v1(0)(0), v1(1)(0)), Vector(v1(0)(1), v1(1)(1))) @@ -392,12 +407,13 @@ class SetOpsTests extends AnyFunSuite { (c, ci) <- trans(right).zipWithIndex } yield ((ri, ci), dot(r, c)) - data.foldLeft(res) { case (v, ((r, c), d)) => v.updated(r, v(r).updated(c, d)) } + data.foldLeft(res) { case (v, ((r, c), d)) => + v.updated(r, v(r).updated(c, d)) + } } def norm(left: Vector[Vector[Double]]): Double = - left.map(_.map { x => x*x }.sum).sum - + left.map(_.map { x => x * x }.sum).sum val genMat: Gen[Vector[Vector[Double]]] = { val elem = Gen.choose(-1.0, 1.0) @@ -410,7 +426,9 @@ class SetOpsTests extends AnyFunSuite { } forAll(genMat, Gen.listOfN(5, genMat)) { (v0, prods) => - val res = SetOps.greedySearch(5, v0, prods)({(v, ps) => ps.foldLeft(v)(mult(_, _))})(norm(_)) + val res = SetOps.greedySearch(5, v0, prods)({ (v, ps) => + ps.foldLeft(v)(mult(_, _)) + })(norm(_)) val normRes = norm(res) val naive = norm(prods.foldLeft(v0)(mult(_, _))) assert(normRes <= naive) @@ -427,7 +445,10 @@ class SetOpsTests extends AnyFunSuite { val bb = pb(b) val bc = pc(b) if (!right(b)) { - assert(!left(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + !left(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } @@ -443,7 +464,10 @@ class SetOpsTests extends AnyFunSuite { val bb = pb(b) val bc = pc(b) if (left(b)) { - assert(right(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + right(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } @@ -458,7 +482,10 @@ class SetOpsTests extends AnyFunSuite { val ba = pa(b) val bb = pb(b) val bc = pc(b) - assert(left(b) == right(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + left(b) == right(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } @@ -473,18 +500,28 @@ class SetOpsTests extends AnyFunSuite { val bb = pb(b) val bc = pc(b) if (!right(b)) { - assert(!left(b), s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}") + assert( + !left(b), + s"ba = $ba, bb = $bb, bc = $bc, ${left(b)} != ${right(b)}" + ) } } } } test("A1 x B1 - A2 x B2 = (A1 n A2)x(B1 - B2) u (A1 - A2)xB1") { - forAll { (a1: Predicate[Byte], a2: Predicate[Byte], b1: Predicate[Byte], b2: Predicate[Byte], checks: List[(Byte, Byte)]) => - val left = a1.product(b1) - a2.product(b2) - val right = (a1 && a2).product(b1 - b2) || (a1 - a2).product(b1) - checks.foreach { ab => - assert(left(ab) == right(ab)) - } + forAll { + ( + a1: Predicate[Byte], + a2: Predicate[Byte], + b1: Predicate[Byte], + b2: Predicate[Byte], + checks: List[(Byte, Byte)] + ) => + val left = a1.product(b1) - a2.product(b2) + val right = (a1 && a2).product(b1 - b2) || (a1 - a2).product(b1) + checks.foreach { ab => + assert(left(ab) == right(ab)) + } } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala b/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala index d1e99a276..e99595d1c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala +++ b/core/src/test/scala/org/bykn/bosatsu/pattern/StringSeqPatternSetLaws.scala @@ -11,9 +11,9 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { val Pattern = SeqPattern implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 50000) + // PropertyCheckConfiguration(minSuccessful = 50000) PropertyCheckConfiguration(minSuccessful = 5000) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) // if there are too many wildcards the intersections will blow up def genItem: Gen[Pattern] = StringSeqPatternGen.genPat.map { p => @@ -45,22 +45,31 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { import SeqPart.Lit val regressions: List[(SeqPattern[Char], SeqPattern[Char])] = - (Cat(Lit('1'),Cat(Lit('1'),Cat(Lit('1'),Cat(Lit('1'),Empty)))), - Cat(Lit('0'),Cat(Lit('1'),Cat(Lit('1'),Empty)))) :: - (Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('1'),Cat(Lit('0'),Empty)))), - Cat(Lit('0'),Cat(Lit('1'),Empty))) :: - Nil - - regressions.foreach { case (a, b) => subsetConsistencyLaw(a, b, Eq.fromUniversalEquals) } + ( + Cat(Lit('1'), Cat(Lit('1'), Cat(Lit('1'), Cat(Lit('1'), Empty)))), + Cat(Lit('0'), Cat(Lit('1'), Cat(Lit('1'), Empty))) + ) :: + ( + Cat(Lit('1'), Cat(Lit('0'), Cat(Lit('1'), Cat(Lit('0'), Empty)))), + Cat(Lit('0'), Cat(Lit('1'), Empty)) + ) :: + Nil + + regressions.foreach { case (a, b) => + subsetConsistencyLaw(a, b, Eq.fromUniversalEquals) + } } test("*x* problems") { import SeqPattern.{Cat, Empty} import SeqPart.{Lit, Wildcard} - val x = Cat(Wildcard,Cat(Lit('q'),Cat(Wildcard,Cat(Lit('p'),Cat(Wildcard,Empty))))) - val y = Cat(Wildcard,Cat(Lit('p'),Cat(Wildcard,Empty))) - val z = Cat(Wildcard,Cat(Lit('q'),Cat(Wildcard,Empty))) + val x = Cat( + Wildcard, + Cat(Lit('q'), Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty)))) + ) + val y = Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty))) + val z = Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))) // note y and z are clearly bigger than x because they are prefix/suffix that end/start with // Wildcard assert(setOps.difference(x, y).isEmpty) @@ -71,27 +80,34 @@ class StringSeqPatternSetLaws extends SetOpsLaws[SeqPattern[Char]] { import SeqPattern.{Cat, Empty} import SeqPart.{AnyElem, Lit, Wildcard} - val regressions: List[(SeqPattern[Char], SeqPattern[Char], SeqPattern[Char])] = - (Cat(Wildcard, Empty), - Cat(AnyElem,Cat(Lit('1'),Cat(AnyElem,Empty))), - Cat(AnyElem,Cat(Lit('1'),Cat(Lit('0'),Empty)))) :: - (Cat(Wildcard,Cat(Lit('0'),Empty)), - Cat(AnyElem,Cat(Lit('1'),Cat(AnyElem,Cat(Lit('0'),Empty)))), - Cat(AnyElem,Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('0'),Empty))))) :: - (Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))), + val regressions + : List[(SeqPattern[Char], SeqPattern[Char], SeqPattern[Char])] = + ( Cat(Wildcard, Empty), - Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty)))) :: - /* - * This fails currently - * see: https://github.com/johnynek/bosatsu/issues/486 + Cat(AnyElem, Cat(Lit('1'), Cat(AnyElem, Empty))), + Cat(AnyElem, Cat(Lit('1'), Cat(Lit('0'), Empty))) + ) :: + ( + Cat(Wildcard, Cat(Lit('0'), Empty)), + Cat(AnyElem, Cat(Lit('1'), Cat(AnyElem, Cat(Lit('0'), Empty)))), + Cat(AnyElem, Cat(Lit('1'), Cat(Lit('0'), Cat(Lit('0'), Empty)))) + ) :: + ( + Cat(Wildcard, Cat(Lit('q'), Cat(Wildcard, Empty))), + Cat(Wildcard, Empty), + Cat(Wildcard, Cat(Lit('p'), Cat(Wildcard, Empty))) + ) :: + /* + * This fails currently + * see: https://github.com/johnynek/bosatsu/issues/486 { val p1 = Cat(Wildcard,Cat(Lit('1'),Cat(Lit('0'),Cat(Lit('0'),Empty)))) val p2 = Cat(AnyElem,Cat(Lit('1'),Cat(Wildcard,Cat(Lit('0'),Empty)))) val p3 = Cat(Lit('1'),Cat(Lit('1'),Cat(Wildcard,Cat(Lit('0'),Empty)))) (p1, p2, p3) } :: - */ - Nil + */ + Nil regressions.foreach { case (a, b, c) => diffIntersectionLaw(a, b, c) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala index 74f7b800e..e53a37e81 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala @@ -24,16 +24,29 @@ object NTypeGen { consIdentGen.map(TypeName(_)) val keyWords = Set( - "if", "ffi", "match", "struct", "enum", "else", "elif", - "def", "external", "package", "import", "export", "forall", - "recur", "recursive") + "if", + "ffi", + "match", + "struct", + "enum", + "else", + "elif", + "def", + "external", + "package", + "import", + "export", + "forall", + "recur", + "recursive" + ) val lowerIdent: Gen[String] = (for { c <- lower cnt <- Gen.choose(0, 10) rest <- Gen.listOfN(cnt, identC) - } yield (c :: rest).mkString).filter { s=> !keyWords(s) } + } yield (c :: rest).mkString).filter { s => !keyWords(s) } val packageNameGen: Gen[PackageName] = for { @@ -45,7 +58,8 @@ object NTypeGen { } yield PackageName(NonEmptyList(h, tail)) val genConst: Gen[Type.Const] = - Gen.zip(packageNameGen, typeNameGen) + Gen + .zip(packageNameGen, typeNameGen) .map { case (p, n) => Type.Const.Defined(p, n) } val genBound: Gen[Type.Var.Bound] = @@ -112,21 +126,31 @@ object NTypeGen { // either Unit, TupleConsType(a, tuple) Gen.oneOf( Gen.const(UnitType), - Gen.zip(recurse, recTup).map { case (h, t) => Type.TyApply(Type.TyApply(TupleConsType, h), t) }) + Gen.zip(recurse, recTup).map { case (h, t) => + Type.TyApply(Type.TyApply(TupleConsType, h), t) + } + ) } Gen.frequency( (6, Gen.oneOf(t0)), - (2, for { - cons <- Gen.oneOf(t1) - param <- recurse - } yield TyApply(cons, param)), + ( + 2, + for { + cons <- Gen.oneOf(t1) + param <- recurse + } yield TyApply(cons, param) + ), (1, tupleTypes), - (1, for { - cons <- Gen.oneOf(t2) - param1 <- recurse - param2 <- recurse - } yield TyApply(TyApply(cons, param1), param2))) + ( + 1, + for { + cons <- Gen.oneOf(t2) + param1 <- recurse + param2 <- recurse + } yield TyApply(TyApply(cons, param1), param2) + ) + ) } def genDepth(d: Int, genC: Option[Gen[Type.Const]]): Gen[Type] = @@ -141,11 +165,12 @@ object NTypeGen { in <- recurse } yield Type.forAll(as, in) - val genApply = Gen.zip(recurse, recurse).map { case (a, b) => Type.TyApply(a, b) } + val genApply = + Gen.zip(recurse, recurse).map { case (a, b) => Type.TyApply(a, b) } Gen.oneOf(recurse, genApply, genForAll) } - - val genDepth03: Gen[Type] = Gen.choose(0, 3).flatMap(genDepth(_, Some(genConst))) + val genDepth03: Gen[Type] = + Gen.choose(0, 3).flatMap(genDepth(_, Some(genConst))) } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 0d3e20ed1..f7c2cc03d 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -17,17 +17,20 @@ class RankNInferTest extends AnyFunSuite { val emptyRegion: Region = Region(0, 0) - implicit val unitRegion: HasRegion[Unit] = HasRegion.instance(_ => emptyRegion) + implicit val unitRegion: HasRegion[Unit] = + HasRegion.instance(_ => emptyRegion) private def strToConst(str: Identifier.Constructor): Type.Const = str.asString match { - case "Int" => Type.Const.predef("Int") + case "Int" => Type.Const.predef("Int") case "String" => Type.Const.predef("String") - case "List" => Type.Const.predef("List") - case _ => Type.Const.Defined(testPackage, TypeName(str)) + case "List" => Type.Const.predef("List") + case _ => Type.Const.Defined(testPackage, TypeName(str)) } - def asFullyQualified(ns: Iterable[(Identifier, Type)]): Map[Infer.Name, Type] = + def asFullyQualified( + ns: Iterable[(Identifier, Type)] + ): Map[Infer.Name, Type] = ns.iterator.map { case (n, t) => ((Some(testPackage), n), t) }.toMap def typeFrom(str: String): Type = { @@ -39,7 +42,8 @@ class RankNInferTest extends AnyFunSuite { val t1 = typeFrom(left) val t2 = typeFrom(right) - Infer.substitutionCheck(t1, t2, emptyRegion, emptyRegion) + Infer + .substitutionCheck(t1, t2, emptyRegion, emptyRegion) .runFully(Map.empty, Map.empty, Type.builtInKinds) } @@ -49,7 +53,10 @@ class RankNInferTest extends AnyFunSuite { } def assertTypesDisjoint(left: String, right: String) = - assert(runUnify(left, right).isLeft, s"$left unexpectedly unifies with $right") + assert( + runUnify(left, right).isLeft, + s"$left unexpectedly unifies with $right" + ) def defType(n: String): Type.Const.Defined = Type.Const.Defined(testPackage, TypeName(Identifier.Constructor(n))) @@ -59,31 +66,53 @@ class RankNInferTest extends AnyFunSuite { val withBools: Map[Infer.Name, Type] = Map( - (Some(PackageName.PredefName), Identifier.unsafe("True")) -> Type.BoolType, - (Some(PackageName.PredefName), Identifier.unsafe("False")) -> Type.BoolType) + ( + Some(PackageName.PredefName), + Identifier.unsafe("True") + ) -> Type.BoolType, + ( + Some(PackageName.PredefName), + Identifier.unsafe("False") + ) -> Type.BoolType + ) val boolTypes: Map[(PackageName, Constructor), Infer.Cons] = Map( - ((PackageName.PredefName, Constructor("True")), (Nil, Nil, Type.Const.predef("Bool"))), - ((PackageName.PredefName, Constructor("False")), (Nil, Nil, Type.Const.predef("Bool")))) + ( + (PackageName.PredefName, Constructor("True")), + (Nil, Nil, Type.Const.predef("Bool")) + ), + ( + (PackageName.PredefName, Constructor("False")), + (Nil, Nil, Type.Const.predef("Bool")) + ) + ) def testType[A: HasRegion](term: Expr[A], ty: Type) = Infer.typeCheck(term).runFully(withBools, boolTypes, Map.empty) match { - case Left(err) => assert(false, err) + case Left(err) => assert(false, err) case Right(tpe) => assert(tpe.getType == ty, term.toString) } def testLetTypes[A: HasRegion](terms: List[(String, Expr[A], Type)]) = - Infer.typeCheckLets(testPackage, terms.map { case (k, v, _) => (Identifier.Name(k), RecursionKind.NonRecursive, v) }) + Infer + .typeCheckLets( + testPackage, + terms.map { case (k, v, _) => + (Identifier.Name(k), RecursionKind.NonRecursive, v) + } + ) .runFully(withBools, boolTypes, Type.builtInKinds) match { - case Left(err) => assert(false, err) - case Right(tpes) => - assert(tpes.size == terms.size) - terms.zip(tpes).foreach { case ((n, exp, expt), (n1, _, te)) => - assert(n == n1.asString, s"the name changed: $n != $n1") - assert(te.getType == expt, s"$n = $exp failed to typecheck to $expt, got ${te.getType}") - } - } - + case Left(err) => assert(false, err) + case Right(tpes) => + assert(tpes.size == terms.size) + terms.zip(tpes).foreach { case ((n, exp, expt), (n1, _, te)) => + assert(n == n1.asString, s"the name changed: $n != $n1") + assert( + te.getType == expt, + s"$n = $exp failed to typecheck to $expt, got ${te.getType}" + ) + } + } def lit(i: Int): Expr[Unit] = Literal(Lit(i.toLong), ()) def lit(b: Boolean): Expr[Unit] = @@ -95,36 +124,41 @@ class RankNInferTest extends AnyFunSuite { Lambda(NonEmptyList.one((Identifier.Name(arg), None)), result, ()) def v(name: String): Expr[Unit] = Identifier.unsafe(name) match { - case c@Identifier.Constructor(_) => Global(testPackage, c, ()) - case b: Identifier.Bindable => Local(b, ()) + case c @ Identifier.Constructor(_) => Global(testPackage, c, ()) + case b: Identifier.Bindable => Local(b, ()) } def ann(expr: Expr[Unit], t: Type): Expr[Unit] = Annotation(expr, t, ()) - def app(fn: Expr[Unit], arg: Expr[Unit]): Expr[Unit] = App(fn, NonEmptyList.one(arg), ()) + def app(fn: Expr[Unit], arg: Expr[Unit]): Expr[Unit] = + App(fn, NonEmptyList.one(arg), ()) def alam(arg: String, tpe: Type, res: Expr[Unit]): Expr[Unit] = Lambda(NonEmptyList.one((Identifier.Name(arg), Some(tpe))), res, ()) - def ife(cond: Expr[Unit], ift: Expr[Unit], iff: Expr[Unit]): Expr[Unit] = Expr.ifExpr(cond, ift, iff, ()) - def matche(arg: Expr[Unit], branches: NonEmptyList[(Pattern[String, Type], Expr[Unit])]): Expr[Unit] = - Match(arg, + def ife(cond: Expr[Unit], ift: Expr[Unit], iff: Expr[Unit]): Expr[Unit] = + Expr.ifExpr(cond, ift, iff, ()) + def matche( + arg: Expr[Unit], + branches: NonEmptyList[(Pattern[String, Type], Expr[Unit])] + ): Expr[Unit] = + Match( + arg, branches.map { case (p, e) => val p1 = p.mapName { n => (testPackage, Constructor(n)) } (p1, e) }, - ()) + () + ) - /** - * Check that a no import program has a given type - */ + /** Check that a no import program has a given type + */ def parseProgram(statement: String, tpe: String) = checkLast(statement) { te0 => - val te = te0 // TypedExprNormalization.normalize(te0).getOrElse(te0) te.traverseType[cats.Id] { - case t@Type.TyVar(Type.Var.Skolem(_, _, _)) => + case t @ Type.TyVar(Type.Var.Skolem(_, _, _)) => fail(s"illegate skolem ($t) escape in $te") t - case t@Type.TyMeta(_) => + case t @ Type.TyMeta(_) => fail(s"illegate meta ($t) escape in $te") t case good => @@ -134,10 +168,12 @@ class RankNInferTest extends AnyFunSuite { val rendered = te.repr val tp = te.getType lazy val teStr = Type.fullyResolvedDocument.document(tp).render(80) - assert(Type.freeTyVars(tp :: Nil).isEmpty, s"illegal inferred type: $teStr, in: $rendered") + assert( + Type.freeTyVars(tp :: Nil).isEmpty, + s"illegal inferred type: $teStr, in: $rendered" + ) - assert(Type.metaTvs(tp :: Nil).isEmpty, - s"illegal inferred type: $teStr") + assert(Type.metaTvs(tp :: Nil).isEmpty, s"illegal inferred type: $teStr") assert(te.getType.sameAs(typeFrom(tpe))) } @@ -145,9 +181,8 @@ class RankNInferTest extends AnyFunSuite { def checkTERepr(statement: String, repr: String) = checkLast(statement) { te => assert(te.repr == repr) } - /** - * Test that a program is ill-typed - */ + /** Test that a program is ill-typed + */ def parseProgramIllTyped(statement: String) = { val stmts = Parser.unsafeParse(Statement.parser, statement) Package.inferBody(testPackage, Nil, stmts) match { @@ -163,7 +198,10 @@ class RankNInferTest extends AnyFunSuite { assertTypesUnify("forall a, b. a -> b", "forall b. b -> Int") assertTypesUnify("forall a, b. a -> b", "forall a. a -> (forall b. b -> b)") assertTypesUnify("(forall a. a) -> Int", "(forall a. a) -> Int") - assertTypesUnify("(forall a. a -> Int) -> Int", "(forall a. a -> Int) -> Int") + assertTypesUnify( + "(forall a. a -> Int) -> Int", + "(forall a. a -> Int) -> Int" + ) assertTypesUnify("forall a, b. a -> b -> b", "forall a. a -> a -> a") // these aren't disjoint but the right is more polymorphic assertTypesDisjoint("forall a. a -> a -> a", "forall a, b. a -> b -> b") @@ -173,7 +211,7 @@ class RankNInferTest extends AnyFunSuite { assertTypesUnify("forall a, f: * -> *. f[a]", "forall x. List[x]") assertTypesUnify("forall a, f: +* -> *. f[a]", "forall x. List[x]") assertTypesDisjoint("forall a, f: -* -> *. f[a]", "forall x. List[x]") - //assertTypesUnify("(forall a, b. a -> b)[x, y]", "z -> w") + // assertTypesUnify("(forall a, b. a -> b)[x, y]", "z -> w") assertTypesDisjoint("Int", "String") assertTypesDisjoint("Int -> Unit", "String") @@ -185,16 +223,29 @@ class RankNInferTest extends AnyFunSuite { testType(lit(100), Type.IntType) testType(let("x", lambda("y", v("y")), lit(100)), Type.IntType) - testType(lambda("y", v("y")), - ForAll(NonEmptyList.of(b("a")), - Type.Fun(Type.TyVar(Bound("a")),Type.TyVar(Bound("a"))))) - testType(lambda("y", lambda("z", v("y"))), - ForAll(NonEmptyList.of(b("a"), b("b")), - Type.Fun(Type.TyVar(Bound("a")), - Type.Fun(Type.TyVar(Bound("b")),Type.TyVar(Bound("a")))))) + testType( + lambda("y", v("y")), + ForAll( + NonEmptyList.of(b("a")), + Type.Fun(Type.TyVar(Bound("a")), Type.TyVar(Bound("a"))) + ) + ) + testType( + lambda("y", lambda("z", v("y"))), + ForAll( + NonEmptyList.of(b("a"), b("b")), + Type.Fun( + Type.TyVar(Bound("a")), + Type.Fun(Type.TyVar(Bound("b")), Type.TyVar(Bound("a"))) + ) + ) + ) testType(app(lambda("x", v("x")), lit(100)), Type.IntType) - testType(ann(app(lambda("x", v("x")), lit(100)), Type.IntType), Type.IntType) + testType( + ann(app(lambda("x", v("x")), lit(100)), Type.IntType), + Type.IntType + ) testType(app(alam("x", Type.IntType, v("x")), lit(100)), Type.IntType) // test branches @@ -202,37 +253,58 @@ class RankNInferTest extends AnyFunSuite { testType(let("x", lit(0), ife(lit(true), v("x"), lit(1))), Type.IntType) val identFnType = - ForAll(NonEmptyList.of(b("a")), - Type.Fun(Type.TyVar(Bound("a")), Type.TyVar(Bound("a")))) - testType(let("x", lambda("y", v("y")), - ife(lit(true), v("x"), - ann(lambda("x", v("x")), identFnType))), identFnType) + ForAll( + NonEmptyList.of(b("a")), + Type.Fun(Type.TyVar(Bound("a")), Type.TyVar(Bound("a"))) + ) + testType( + let( + "x", + lambda("y", v("y")), + ife(lit(true), v("x"), ann(lambda("x", v("x")), identFnType)) + ), + identFnType + ) // test some lets testLetTypes( List( ("x", lit(100), Type.IntType), - ("y", Expr.Global(testPackage, Identifier.Name("x"), ()), Type.IntType))) + ("y", Expr.Global(testPackage, Identifier.Name("x"), ()), Type.IntType) + ) + ) } test("match inference") { testType( - matche(lit(10), + matche( + lit(10), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testType( - matche(lit(true), + matche( + lit(true), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testType( - matche(lit(true), + matche( + lit(true), NonEmptyList.of( (Pattern.Annotation(Pattern.WildCard, Type.BoolType), lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) } object OptionTypes { @@ -242,11 +314,23 @@ class RankNInferTest extends AnyFunSuite { val pn = testPackage val definedOption = Map( ((pn, Constructor("Some")), (Nil, List(Type.IntType), optName)), - ((pn, Constructor("None")), (Nil, Nil, optName))) + ((pn, Constructor("None")), (Nil, Nil, optName)) + ) val definedOptionGen = Map( - ((pn, Constructor("Some")), (List((Bound("a"), Kind.Type.co)), List(Type.TyVar(Bound("a"))), optName)), - ((pn, Constructor("None")), (List((Bound("a"), Kind.Type.co)), Nil, optName))) + ( + (pn, Constructor("Some")), + ( + List((Bound("a"), Kind.Type.co)), + List(Type.TyVar(Bound("a"))), + optName + ) + ), + ( + (pn, Constructor("None")), + (List((Bound("a"), Kind.Type.co)), Nil, optName) + ) + ) } test("match with custom non-generic types") { @@ -258,41 +342,65 @@ class RankNInferTest extends AnyFunSuite { val kinds = Type.builtInKinds.updated(optName, Kind(Kind.Type.co)) def testWithOpt[A: HasRegion](term: Expr[A], ty: Type) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - definedOption ++ boolTypes, - kinds) match { - case Left(err) => assert(false, err) + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + definedOption ++ boolTypes, + kinds + ) match { + case Left(err) => assert(false, err) case Right(tpe) => assert(tpe.getType == ty, term.toString) } def failWithOpt[A: HasRegion](term: Expr[A]) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - definedOption ++ boolTypes, - kinds) match { + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + definedOption ++ boolTypes, + kinds + ) match { case Left(_) => assert(true) - case Right(tpe) => assert(false, s"expected to fail, but inferred type $tpe") + case Right(tpe) => + assert(false, s"expected to fail, but inferred type $tpe") } testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( - (Pattern.PositionalStruct("Some", List(Pattern.Var(Identifier.Name("a")))), v("a")), + ( + Pattern.PositionalStruct( + "Some", + List(Pattern.Var(Identifier.Name("a"))) + ), + v("a") + ), (Pattern.PositionalStruct("None", Nil), lit(42)) - )), Type.IntType) + ) + ), + Type.IntType + ) failWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.PositionalStruct("Foo", List(Pattern.WildCard)), lit(0)) - ))) + ) + ) + ) } test("match with custom generic types") { @@ -303,8 +411,17 @@ class RankNInferTest extends AnyFunSuite { val kinds = Type.builtInKinds.updated(optName, Kind(Kind.Type.co)) val constructors = Map( - (Identifier.unsafe("Some"), Type.ForAll(NonEmptyList.of(b("a")), Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))))), - (Identifier.unsafe("None"), Type.ForAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a")))) + ( + Identifier.unsafe("Some"), + Type.ForAll( + NonEmptyList.of(b("a")), + Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))) + ) + ), + ( + Identifier.unsafe("None"), + Type.ForAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a"))) + ) ) def testWithOpt[A: HasRegion](term: Expr[A], ty: Type) = @@ -313,45 +430,77 @@ class RankNInferTest extends AnyFunSuite { .runFully( withBools ++ asFullyQualified(constructors), definedOptionGen ++ boolTypes, - kinds) match { - case Left(err) => assert(false, err) - case Right(tpe) => assert(tpe.getType == ty, term.toString) - } + kinds + ) match { + case Left(err) => assert(false, err) + case Right(tpe) => assert(tpe.getType == ty, term.toString) + } def failWithOpt[A: HasRegion](term: Expr[A]) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - definedOptionGen ++ boolTypes, - kinds) match { + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + definedOptionGen ++ boolTypes, + kinds + ) match { case Left(_) => assert(true) - case Right(tpe) => assert(false, s"expected to fail, but inferred type $tpe") + case Right(tpe) => + assert(false, s"expected to fail, but inferred type $tpe") } testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.WildCard, lit(0)) - )), Type.IntType) + ) + ), + Type.IntType + ) testWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( - (Pattern.PositionalStruct("Some", List(Pattern.Var(Identifier.Name("a")))), v("a")), + ( + Pattern.PositionalStruct( + "Some", + List(Pattern.Var(Identifier.Name("a"))) + ), + v("a") + ), (Pattern.PositionalStruct("None", Nil), lit(42)) - )), Type.IntType) + ) + ), + Type.IntType + ) // Nested Some testWithOpt( - matche(app(v("Some"), app(v("Some"), lit(1))), + matche( + app(v("Some"), app(v("Some"), lit(1))), NonEmptyList.of( - (Pattern.PositionalStruct("Some", List(Pattern.Var(Identifier.Name("a")))), v("a")) - )), Type.TyApply(optType, Type.IntType)) + ( + Pattern.PositionalStruct( + "Some", + List(Pattern.Var(Identifier.Name("a"))) + ), + v("a") + ) + ) + ), + Type.TyApply(optType, Type.IntType) + ) failWithOpt( - matche(app(v("Some"), lit(1)), + matche( + app(v("Some"), lit(1)), NonEmptyList.of( (Pattern.PositionalStruct("Foo", List(Pattern.WildCard)), lit(0)) - ))) + ) + ) + ) } test("Test a constructor with ForAll") { @@ -362,73 +511,130 @@ class RankNInferTest extends AnyFunSuite { val optType: Type.Tau = Type.TyConst(optName) val pn = testPackage - /** - * struct Pure(pure: forall a. a -> f[a]) - */ + + /** struct Pure(pure: forall a. a -> f[a]) + */ val defined = Map( - ((pn, Constructor("Pure")), (List((Type.Var.Bound("f"), Kind(Kind.Type.in).in)), - List(Type.ForAll(NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)), Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a"))))), - pureName)), - ((pn, Constructor("Some")), (List((Type.Var.Bound("a"), Kind.Type.co)), List(tv("a")), optName)), - ((pn, Constructor("None")), (List((Type.Var.Bound("a"), Kind.Type.co)), Nil, optName))) + ( + (pn, Constructor("Pure")), + ( + List((Type.Var.Bound("f"), Kind(Kind.Type.in).in)), + List( + Type.ForAll( + NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)), + Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a"))) + ) + ), + pureName + ) + ), + ( + (pn, Constructor("Some")), + (List((Type.Var.Bound("a"), Kind.Type.co)), List(tv("a")), optName) + ), + ( + (pn, Constructor("None")), + (List((Type.Var.Bound("a"), Kind.Type.co)), Nil, optName) + ) + ) val constructors = Map( - (Identifier.unsafe("Pure"), Type.ForAll(NonEmptyList.of(b1("f")), - Type.Fun(Type.ForAll(NonEmptyList.of(b("a")), Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a")))), - Type.TyApply(Type.TyConst(pureName), tv("f")) ))), - (Identifier.unsafe("Some"), Type.ForAll(NonEmptyList.of(b("a")), Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))))), - (Identifier.unsafe("None"), Type.ForAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a")))) + ( + Identifier.unsafe("Pure"), + Type.ForAll( + NonEmptyList.of(b1("f")), + Type.Fun( + Type.ForAll( + NonEmptyList.of(b("a")), + Type.Fun(tv("a"), Type.TyApply(tv("f"), tv("a"))) + ), + Type.TyApply(Type.TyConst(pureName), tv("f")) + ) + ) + ), + ( + Identifier.unsafe("Some"), + Type.ForAll( + NonEmptyList.of(b("a")), + Type.Fun(tv("a"), Type.TyApply(optType, tv("a"))) + ) + ), + ( + Identifier.unsafe("None"), + Type.ForAll(NonEmptyList.of(b("a")), Type.TyApply(optType, tv("a"))) + ) ) def testWithTypes[A: HasRegion](term: Expr[A], ty: Type) = - Infer.typeCheck(term).runFully( - withBools ++ asFullyQualified(constructors), - defined ++ boolTypes, - Type.builtInKinds.updated(optName, Kind(Kind.Type.co))) match { - case Left(err) => assert(false, err) + Infer + .typeCheck(term) + .runFully( + withBools ++ asFullyQualified(constructors), + defined ++ boolTypes, + Type.builtInKinds.updated(optName, Kind(Kind.Type.co)) + ) match { + case Left(err) => assert(false, err) case Right(tpe) => assert(tpe.getType == ty, term.toString) } testWithTypes( - app(v("Pure"), v("Some")), Type.TyApply(Type.TyConst(pureName), optType)) + app(v("Pure"), v("Some")), + Type.TyApply(Type.TyConst(pureName), optType) + ) } test("test inference of basic expressions") { - parseProgram("""# + parseProgram( + """# main = (x -> x)(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# x = 1 y = x main = y -""", "Int") +""", + "Int" + ) } test("test inference with partial def annotation") { - parseProgram("""# + parseProgram( + """# ident: forall a. a -> a = x -> x main = ident(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# def ident(x: a): x main = ident(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# def ident(x) -> a: x main = ident(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum MyBool: T, F struct Pair(fst, snd) @@ -444,10 +650,12 @@ res = ( ) main = res -""", "Int") - +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# struct Pair(fst: a, snd: a) @@ -460,43 +668,61 @@ fst = ( ) main = fst -""", "Int") +""", + "Int" + ) } test("test inference with some defined types") { - parseProgram("""# + parseProgram( + """# struct Unit main = Unit -""", "Unit") +""", + "Unit" + ) - parseProgram("""# + parseProgram( + """# enum Option: None Some(a) main = Some(1) -""", "Option[Int]") +""", + "Option[Int]" + ) - parseProgram("""# + parseProgram( + """# enum Option: None Some(a) main = Some -""", "forall a. a -> Option[a]") +""", + "forall a. a -> Option[a]" + ) - parseProgram("""# + parseProgram( + """# id = x -> x main = id -""", "forall a. a -> a") +""", + "forall a. a -> a" + ) - parseProgram("""# + parseProgram( + """# id = x -> x main = id(1) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum Option: None Some(a) @@ -505,9 +731,12 @@ x = Some(1) main = match x: case None: 0 case Some(y): y -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum List: Empty NonEmpty(a: a, tail: b) @@ -516,9 +745,12 @@ x = NonEmpty(1, Empty) main = match x: case Empty: 0 case NonEmpty(y, _): y -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum Opt: None, Some(a) @@ -530,9 +762,12 @@ def optBind(opt, bindFn): case Some(a): bindFn(a) main = Monad(Some, optBind) -""", "Monad[Opt]") +""", + "Monad[Opt]" + ) - parseProgram("""# + parseProgram( + """# enum Opt: None, Some(a) @@ -555,15 +790,18 @@ def use_bind(m, a, b, c): a1.bind(_ -> b1.bind(_ -> c1)) main = use_bind(option_monad, None, None, None) -""", "forall a. Opt[a]") - - // TODO: - // The challenge here is that the naive curried form of the - // def will not see the forall until the final parameter - // we need to bubble up the forall on the whole function. - // - // same as the above with a different order in use_bind - parseProgram("""# +""", + "forall a. Opt[a]" + ) + + // TODO: + // The challenge here is that the naive curried form of the + // def will not see the forall until the final parameter + // we need to bubble up the forall on the whole function. + // + // same as the above with a different order in use_bind + parseProgram( + """# enum Opt: None, Some(a) @@ -586,20 +824,26 @@ def use_bind(a, b, c, m): a1.bind(_ -> b1.bind(_ -> c1)) main = use_bind(None, None, None, option_monad) -""", "forall a. Opt[a]") +""", + "forall a. Opt[a]" + ) } test("test zero arg defs") { - parseProgram("""# + parseProgram( + """# struct Foo fst: Foo = Foo main = fst -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# enum Foo: Bar, Baz(a) @@ -607,13 +851,15 @@ enum Foo: fst: Foo[a] = Bar main = fst -""", "forall a. Foo[a]") +""", + "forall a. Foo[a]" + ) } - test("substition works correctly") { - parseProgram("""# + parseProgram( + """# (id: forall a. a -> a) = x -> x struct Foo @@ -621,9 +867,12 @@ struct Foo def apply(fn, arg: Foo): fn(arg) main = apply(id, Foo) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# (id: forall a. a -> a) = x -> x struct Foo @@ -633,9 +882,12 @@ struct Foo def apply(fn, arg: Foo): fn(arg) main = apply(id, Foo) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct FnWrapper(fn: a -> a) @@ -650,9 +902,12 @@ def apply(fn, arg: Foo): f(arg) main = apply(id, Foo) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Foo (id: forall a. a -> Foo) = _ -> Foo @@ -662,7 +917,9 @@ struct Foo (idGen2: (forall a. a) -> Foo) = id2 main = Foo -""", "Foo") +""", + "Foo" + ) parseProgramIllTyped("""# @@ -674,7 +931,8 @@ struct Foo main = Foo """) - parseProgram("""# + parseProgram( + """# enum Foo: Bar, Baz (bar1: forall a. (Foo -> a) -> a) = fn -> fn(Bar) @@ -697,8 +955,11 @@ enum Foo: Bar, Baz (producer1: Foo -> ((Foo -> Foo) -> Foo)) = producer main = Bar -""", "Foo") - parseProgram("""# +""", + "Foo" + ) + parseProgram( + """# enum Foo: Bar, Baz struct Cont(cont: (b -> a) -> a) @@ -723,7 +984,9 @@ struct Cont(cont: (b -> a) -> a) (producer1: Foo -> Cont[Foo, Foo]) = producer main = Bar -""", "Foo") +""", + "Foo" + ) parseProgramIllTyped("""# enum Foo: Bar, Baz @@ -737,7 +1000,8 @@ struct Cont(cont: (b -> a) -> a) main = Bar """) - parseProgram("""# + parseProgram( + """# struct Foo enum Opt: Nope, Yep(a) @@ -749,9 +1013,11 @@ enum Opt: Nope, Yep(a) (consumer1: (forall a. Opt[a]) -> Foo) = consumer main = Foo -""", "Foo") +""", + "Foo" + ) - parseProgramIllTyped("""# + parseProgramIllTyped("""# struct Foo enum Opt: Nope, Yep(a) @@ -761,7 +1027,7 @@ enum Opt: Nope, Yep(a) main = Foo """) - parseProgramIllTyped("""# + parseProgramIllTyped("""# struct Foo enum Opt: Nope, Yep(a) @@ -772,7 +1038,8 @@ enum Opt: Nope, Yep(a) main = Foo """) - parseProgram("""# + parseProgram( + """# struct Foo enum Opt: Nope, Yep(a) @@ -786,7 +1053,9 @@ struct FnWrapper(fn: a -> b) (consumer1: FnWrapper[forall a. Opt[a], Foo]) = consumer main = Foo -""", "Foo") +""", + "Foo" + ) parseProgramIllTyped("""# struct Foo @@ -816,7 +1085,8 @@ main = Foo } test("def with type annotation and use the types inside") { - parseProgram("""# + parseProgram( + """# struct Pair(fst, snd) @@ -825,11 +1095,13 @@ def fst(p: Pair[a, b]) -> a: f main = fst(Pair(1, "1")) -""", "Int") +""", + "Int" + ) } test("test that we see some ill typed programs") { - parseProgramIllTyped("""# + parseProgramIllTyped("""# def foo(i: Int): i @@ -839,7 +1111,7 @@ main = foo("Not an Int") test("using a literal the wrong type is ill-typed") { - parseProgramIllTyped("""# + parseProgramIllTyped("""# x = "foo" @@ -848,7 +1120,7 @@ main = match x: case y: y """) - parseProgramIllTyped("""# + parseProgramIllTyped("""# x = 1 @@ -879,7 +1151,8 @@ main = 1 } test("structural recursion can be typed") { - parseProgram("""# + parseProgram( + """# enum Nat: Zero, Succ(prev: Nat) @@ -889,9 +1162,12 @@ def len(l): case Succ(p): len(p) main = len(Succ(Succ(Zero))) -""", "Int") +""", + "Int" + ) - parseProgram("""# + parseProgram( + """# enum Nat: Zero, Succ(prev: Nat) @@ -903,12 +1179,15 @@ def len(l): len0(l) main = len(Succ(Succ(Zero))) -""", "Int") +""", + "Int" + ) } test("nested def example") { - parseProgram("""# + parseProgram( + """# struct Pair(first, second) def bar(x): @@ -918,12 +1197,15 @@ def bar(x): baz(10) main = bar(5) -""", "Pair[Int, Int]") +""", + "Pair[Int, Int]" + ) } test("test checkRho on annotated lambda") { - parseProgram("""# + parseProgram( + """# struct Foo struct Bar @@ -935,10 +1217,13 @@ struct Bar dontCall = \(_: (forall a. a) -> Bar) -> Foo (main: Foo) = dontCall(fn) -""", "Foo") +""", + "Foo" + ) } test("ForAll as function arg") { - parseProgram("""# + parseProgram( + """# struct Wrap[bbbb](y1: bbbb) struct Foo[cccc](y2: cccc) struct Nil @@ -954,11 +1239,14 @@ def foo(cra_fn: Wrap[(forall ssss. Foo[ssss]) -> Nil]): match cra_fn: case (_: Wrap[(forall x. Foo[x]) -> Nil]): Nil main = foo -""", "Wrap[(forall ssss. Foo[ssss]) -> Nil] -> Nil") +""", + "Wrap[(forall ssss. Foo[ssss]) -> Nil] -> Nil" + ) } test("use a type annotation inside a def") { - parseProgram("""# + parseProgram( + """# struct Foo struct Bar def ignore(_): Foo @@ -966,9 +1254,12 @@ def add(x): (y: Foo) = x _ = ignore(y) Bar -""", "Foo -> Bar") +""", + "Foo -> Bar" + ) - parseProgram("""# + parseProgram( + """# struct Foo struct Bar(f: Foo) def ignore(_): Foo @@ -976,7 +1267,9 @@ def add(x): ((y: Foo) as b) = x _ = ignore(y) Bar(b) -""", "Foo -> Bar") +""", + "Foo -> Bar" + ) } test("top level matches don't introduce colliding bindings") { @@ -1037,7 +1330,8 @@ struct Bar x: Bar = Foo """) - parseProgram("""# + parseProgram( + """# struct Foo struct Bar @@ -1045,9 +1339,12 @@ x = ( f = Foo f: Foo ) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Foo struct Bar @@ -1055,8 +1352,11 @@ x = ( f: Foo = Foo f ) -""", "Foo") - parseProgram("""# +""", + "Foo" + ) + parseProgram( + """# struct Pair(a, b) struct Foo @@ -1067,9 +1367,12 @@ x = ( _ = ignore(g) f ) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Pair(a, b) struct Foo @@ -1077,17 +1380,23 @@ x = ( Pair(f, _) = Pair(Foo: Foo, Foo) f ) -""", "Foo") +""", + "Foo" + ) - parseProgram("""# + parseProgram( + """# struct Foo x: Foo = Foo -""", "Foo") +""", + "Foo" + ) } test("test inner quantification") { - parseProgram("""# + parseProgram( + """# struct Foo # this should just be: type Foo @@ -1099,6 +1408,8 @@ foo = ( ident(Foo) ) -""", "Foo") +""", + "Foo" + ) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala index 2e60cc99a..33e3bceec 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala @@ -3,18 +3,21 @@ package org.bykn.bosatsu.rankn import cats.data.NonEmptyList import org.bykn.bosatsu.Kind import org.scalacheck.Gen -import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ forAll, PropertyCheckConfiguration } +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks.{ + forAll, + PropertyCheckConfiguration +} import org.scalatest.funsuite.AnyFunSuite class TypeTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - //PropertyCheckConfiguration(minSuccessful = 5000) + // PropertyCheckConfiguration(minSuccessful = 5000) PropertyCheckConfiguration(minSuccessful = 500) - //PropertyCheckConfiguration(minSuccessful = 5) + // PropertyCheckConfiguration(minSuccessful = 5) def parse(s: String): Type = Type.fullyResolvedParser.parseAll(s) match { - case Right(t) => t + case Right(t) => t case Left(err) => sys.error(err.toString) } @@ -29,13 +32,15 @@ class TypeTest extends AnyFunSuite { forAll(Gen.listOf(NTypeGen.genDepth03)) { ts => Type.Tuple(ts) match { case Type.Tuple(ts1) => assert(ts1 == ts) - case notTup => fail(notTup.toString) + case notTup => fail(notTup.toString) } } assert(Type.Tuple.unapply(parse("()")) == Some(Nil)) - assert(Type.Tuple.unapply(parse("(a, b, c)")) == - Some(List("a", "b", "c").map(parse))) + assert( + Type.Tuple.unapply(parse("(a, b, c)")) == + Some(List("a", "b", "c").map(parse)) + ) } test("unapplyAll is the inverse of applyAll") { @@ -44,8 +49,10 @@ class TypeTest extends AnyFunSuite { assert(Type.applyAll(left, args) == ts) } - assert(Type.unapplyAll(parse("foo[bar]")) == - (parse("foo"), List(parse("bar")))) + assert( + Type.unapplyAll(parse("foo[bar]")) == + (parse("foo"), List(parse("bar"))) + ) } test("types are well ordered") { @@ -68,16 +75,16 @@ class TypeTest extends AnyFunSuite { forAll(NTypeGen.genDepth03)(law(_)) - - forAll(NTypeGen.lowerIdent, Gen.choose(Long.MinValue, Long.MaxValue)) { (b, id) => - val str = "$" + b + "$" + id.toString - val tpe = parse(str) - law(tpe) - tpe match { - case Type.TyVar(Type.Var.Skolem(b1, k1, i1)) => - assert((b1, k1, i1) === (b, Kind.Type ,id)) - case other => fail(other.toString) - } + forAll(NTypeGen.lowerIdent, Gen.choose(Long.MinValue, Long.MaxValue)) { + (b, id) => + val str = "$" + b + "$" + id.toString + val tpe = parse(str) + law(tpe) + tpe match { + case Type.TyVar(Type.Var.Skolem(b1, k1, i1)) => + assert((b1, k1, i1) === (b, Kind.Type, id)) + case other => fail(other.toString) + } } forAll { (l: Long) => @@ -86,8 +93,10 @@ class TypeTest extends AnyFunSuite { } test("test all binders") { - assert(Type.allBinders.filter(_.name.startsWith("a")).take(100).map(_.name) == - ("a" #:: Stream.iterate(0)(_ + 1).map { i => s"a$i" }).take(100)) + assert( + Type.allBinders.filter(_.name.startsWith("a")).take(100).map(_.name) == + ("a" #:: Stream.iterate(0)(_ + 1).map { i => s"a$i" }).take(100) + ) } test("tyVarBinders is identity for Bound") { @@ -109,7 +118,7 @@ class TypeTest extends AnyFunSuite { test("hasNoVars fully recurses") { def allTypesIn(t: Type): List[Type] = t match { - case f@Type.ForAll(bounds, in) => + case f @ Type.ForAll(bounds, in) => // filter bounds out, since they are shadowed val boundSet = bounds.toList.iterator.map(_._1).toSet[Type.Var] f :: (allTypesIn(in).filterNot { it => @@ -117,8 +126,8 @@ class TypeTest extends AnyFunSuite { // if we intersect, this is not a legit type to consider (boundSet & frees).nonEmpty }) - case t@Type.TyApply(a, b) => t :: allTypesIn(a) ::: allTypesIn(b) - case other => other :: Nil + case t @ Type.TyApply(a, b) => t :: allTypesIn(a) ::: allTypesIn(b) + case other => other :: Nil } def law(t: Type) = { @@ -133,10 +142,19 @@ class TypeTest extends AnyFunSuite { val pastFails = List( - Type.ForAll(NonEmptyList.of((Type.Var.Bound("x"), Kind.Type), (Type.Var.Bound("ogtumm"), Kind.Type), (Type.Var.Bound("t"), Kind.Type)), - Type.TyVar(Type.Var.Bound("x"))), - Type.ForAll(NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)),Type.TyVar(Type.Var.Bound("a"))) + Type.ForAll( + NonEmptyList.of( + (Type.Var.Bound("x"), Kind.Type), + (Type.Var.Bound("ogtumm"), Kind.Type), + (Type.Var.Bound("t"), Kind.Type) + ), + Type.TyVar(Type.Var.Bound("x")) + ), + Type.ForAll( + NonEmptyList.of((Type.Var.Bound("a"), Kind.Type)), + Type.TyVar(Type.Var.Bound("a")) ) + ) pastFails.foreach(law) } @@ -174,7 +192,8 @@ class TypeTest extends AnyFunSuite { def genSubs(depth: Int): Gen[Map[Type.Var, Type]] = { val pair = Gen.zip( NTypeGen.genBound, - NTypeGen.genDepth(depth, Some(NTypeGen.genConst))) + NTypeGen.genDepth(depth, Some(NTypeGen.genConst)) + ) Gen.mapOf(pair) } @@ -219,7 +238,9 @@ class TypeTest extends AnyFunSuite { // now subs1 has keys that can be completely removed, so // after substitution, those keys should be gone val t1 = Type.substituteVar(t, subs1) - assert((Type.freeBoundTyVars(t1 :: Nil).toSet & subs1.keySet) == Set.empty) + assert( + (Type.freeBoundTyVars(t1 :: Nil).toSet & subs1.keySet) == Set.empty + ) } forAll(NTypeGen.genDepth03, genSubs(3))(law _) @@ -233,7 +254,7 @@ class TypeTest extends AnyFunSuite { } yield NonEmptyList(head, tail) forAll(genArgs, NTypeGen.genDepth03) { (args, res) => - val fnType = Type.Fun(args, res) + val fnType = Type.Fun(args, res) fnType match { case Type.Fun(args1, res1) => assert(args1 == args) diff --git a/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala b/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala index 7929ba8a5..122522bda 100644 --- a/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala +++ b/jsapi/src/main/scala/org/bykn/bosatsu/jsapi/JsApi.scala @@ -25,12 +25,14 @@ object JsApi { class EvalSuccess(val result: js.Any) extends js.Object - /** - * mainPackage can be null, in which case we find the package - * in mainFile - */ + /** mainPackage can be null, in which case we find the package in mainFile + */ @JSExport - def evaluate(mainPackage: String, mainFile: String, files: js.Dictionary[String]): EvalSuccess | Error = { + def evaluate( + mainPackage: String, + mainFile: String, + files: js.Dictionary[String] + ): EvalSuccess | Error = { val baseArgs = "--package_root" :: "" :: "--color" :: "html" :: Nil val main = if (mainPackage != null) "--main" :: mainPackage :: baseArgs @@ -39,8 +41,9 @@ object JsApi { case Left(err) => new Error(s"error: ${err.getMessage}") case Right(module.Output.EvaluationResult(_, tpe, resDoc)) => - val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) - val doc = resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) + val tDoc = rankn.Type.fullyResolvedDocument.document(tpe) + val doc = + resDoc.value + (Doc.lineOrEmpty + Doc.text(": ") + tDoc).nested(4) new EvalSuccess(doc.render(80)) case Right(other) => new Error(s"internal error. got unexpected result: $other") @@ -49,38 +52,43 @@ object JsApi { def jsonToAny(j: Json): js.Any = j match { - case Json.JString(s) => s + case Json.JString(s) => s case Json.JNumberStr(str) => // javascript only really has doubles try str.toDouble catch { case (_: NumberFormatException) => str } - case Json.JBool.True => true + case Json.JBool.True => true case Json.JBool.False => false - case Json.JNull => null + case Json.JNull => null case Json.JArray(items) => val ary = new js.Array[js.Any] items.foreach { j => ary += jsonToAny(j) } ary case Json.JObject(kvs) => - js.Dictionary[js.Any]( - kvs.map { case (k, v) => - (k, jsonToAny(v)) - } :_*) + js.Dictionary[js.Any](kvs.map { case (k, v) => + (k, jsonToAny(v)) + }: _*) } - /** - * mainPackage can be null, in which case we find the package - * in mainFile - */ + /** mainPackage can be null, in which case we find the package in mainFile + */ @JSExport - def evaluateJson(mainPackage: String, mainFile: String, files: js.Dictionary[String]): EvalSuccess | Error = { + def evaluateJson( + mainPackage: String, + mainFile: String, + files: js.Dictionary[String] + ): EvalSuccess | Error = { val baseArgs = "--package_root" :: "" :: "--color" :: "html" :: Nil val main = if (mainPackage != null) "--main" :: mainPackage :: baseArgs else "--main_file" :: mainFile :: baseArgs - module.runWith(files)("json" :: "write" :: "--output" :: "" :: main ::: makeInputArgs(files.keys)) match { + module.runWith(files)( + "json" :: "write" :: "--output" :: "" :: main ::: makeInputArgs( + files.keys + ) + ) match { case Left(err) => new Error(s"error: ${err.getMessage}") case Right(module.Output.JsonOutput(json, _)) => diff --git a/project/plugins.sbt b/project/plugins.sbt index 238967dde..5bd035322 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -9,6 +9,5 @@ addSbtPlugin("org.scalameta" % "sbt-native-image" % "0.3.4") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "2.0.9") addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.6") - // This is adding this compiler plugin as a dependency for the build, not the code itself libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.11.13"