Skip to content

Commit

Permalink
add some more test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 19, 2023
1 parent 5f1a233 commit 567a257
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 9 deletions.
47 changes: 44 additions & 3 deletions core/src/main/scala/org/bykn/bosatsu/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ sealed abstract class Expr[T] {

lazy val globals: Set[Expr.Global[T]] = {
import Expr._
// nearly identical code to TypedExpr.freeVarsDup, bugs should be fixed in both places
this match {
case Generic(_, expr) =>
expr.globals
Expand All @@ -103,7 +102,7 @@ sealed abstract class Expr[T] {
def replaceTag(t: T): Expr[T] = {
import Expr._
this match {
case Generic(_, _) => this
case g@Generic(_, e) => g.copy(in = e.replaceTag(t))
case a@Annotation(_, _, _) => a.copy(tag = t)
case l@Local(_, _) => l.copy(tag = t)
case g @ Global(_, _, _) => g.copy(tag = t)
Expand All @@ -130,10 +129,52 @@ object Expr {
case class Global[T](pack: PackageName, name: Identifier, tag: T) extends Name[T]
case class App[T](fn: Expr[T], args: NonEmptyList[Expr[T]], tag: T) extends Expr[T]
case class Lambda[T](args: NonEmptyList[(Bindable, Option[Type])], expr: Expr[T], tag: T) extends Expr[T]
case class Let[T](arg: Bindable, expr: Expr[T], in: Expr[T], recursive: RecursionKind, tag: T) extends Expr[T]
case class Let[T](arg: Bindable, expr: Expr[T], in: Expr[T], recursive: RecursionKind, tag: T) extends Expr[T] {
def flatten: (NonEmptyList[(Bindable, RecursionKind, Expr[T], T)], Expr[T]) = {
val thisLet = (arg, recursive, expr, tag)

in match {
case let@Let(_, _, _, _, _) =>
val (lets, finalIn) = let.flatten
(thisLet :: lets, finalIn)
case _ =>
// this is the final let
(NonEmptyList.one(thisLet), in)
}
}
}
case class Literal[T](lit: Lit, tag: T) extends Expr[T]
case class Match[T](arg: Expr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr[T])], tag: T) extends Expr[T]

// Inverse of `Let.flatten`
def lets[T](binds: List[(Bindable, RecursionKind, Expr[T], T)], in: Expr[T]): Expr[T] =
binds match {
case Nil => in
case (b, r, e, t) :: tail =>
val res = lets(tail, in)
Let(b, e, res, r, t)
}

object Annotated {
def unapply[A](expr: Expr[A]): Option[Type] =
expr match {
case Annotation(_, tpe, _) => Some(tpe)
case Lambda(args, Annotated(res), _) =>
args.traverse { case (_, ot) => ot }
.map { argTpes =>
Type.Fun(argTpes, res)
}
case Literal(lit, _) => Some(Type.getTypeOf(lit))
case Let(_, _, Annotated(t), _, _) => Some(t)
case Match(_, branches, _) =>
branches.traverse { case (_, expr) => unapply(expr) }
.flatMap { allAnnotated =>
if (allAnnotated.tail.forall(_ === allAnnotated.head)) Some(allAnnotated.head)
else None
}
case _ => None
}
}

def forAll[A](tpeArgs: List[(Type.Var.Bound, Kind)], expr: Expr[A]): Expr[A] =
NonEmptyList.fromList(tpeArgs) match {
Expand Down
12 changes: 6 additions & 6 deletions core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,9 @@ object Infer {
// compilers/evaluation can possibly optimize non-recursive
// cases differently
val rhsBody = rhs match {
case Annotation(expr, tpe, _) =>
case Expr.Annotated(tpe) =>
extendEnv(name, tpe) {
checkSigma(expr, tpe).parProduct(typeCheckRho(body, expect))
checkSigma(rhs, tpe).parProduct(typeCheckRho(body, expect))
}
case _ =>
newMetaType(Kind.Type) // the kind of a let value is a Type
Expand Down Expand Up @@ -921,9 +921,9 @@ object Infer {
// so any recursion in this case won't typecheck, and shadowing rules are
// in place
val rhsBody = rhs match {
case Annotation(expr, tpe, _) =>
case Expr.Annotated(tpe) =>
// check in parallel so we collect more errors
checkSigma(expr, tpe)
checkSigma(rhs, tpe)
.parProduct(
extendEnv(name, tpe) { typeCheckRho(body, expect) }
)
Expand Down Expand Up @@ -1400,8 +1400,8 @@ object Infer {
private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] =
// values are of kind Type
expr match {
case Expr.Annotation(e, tpe, _) =>
extendEnv(name, tpe)(checkSigma(e, tpe))
case Expr.Annotated(tpe) =>
extendEnv(name, tpe)(checkSigma(expr, tpe))
case _ =>
newMetaType(Kind.Type).flatMap { tpe =>
extendEnv(name, tpe)(typeCheckMeta(expr, Some((name, tpe, region(expr)))))
Expand Down
34 changes: 34 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/Gen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1233,4 +1233,38 @@ object Generators {

def genPackage[A](genA: Gen[A], maxSize: Int): Gen[Map[PackageName, Package.Typed[A]]] =
genPackagesSt(genA, maxSize).runS(Map.empty)

object Exprs {
def gen[A](genA: Gen[A], depth: Int): Gen[Expr[A]] = {
import Expr._

val roots: Gen[Expr[A]] =
Gen.frequency(
(1, Gen.zip(genLit, genA).map { case (l, t) => Literal(l, t) }),
(1, Gen.zip(bindIdentGen, genA).map { case (b, t) => Local(b, t) }),
(1, Gen.zip(NTypeGen.packageNameGen, identifierGen, genA).map { case (p, i, t) => Global(p, i, t) })
)

if (depth <= 0) roots
else {
val recur = Gen.lzy(gen(genA, depth - 1))
Gen.frequency(
(1, roots),
(1, Gen.zip(recur, NTypeGen.genDepth03, genA).map { case (e, t, tag) => Annotation(e, t, tag) } ),
(1, Gen.zip(smallNonEmptyList(Gen.zip(NTypeGen.genBound, NTypeGen.genKind), 4), recur).map { case (ts, in) =>
Generic(ts, in)
}),
(2, Gen.zip(recur, smallNonEmptyList(recur, 5), genA).map { case (fn, as, t) => App(fn, as, t) }),
(2, Gen.zip(smallNonEmptyList(Gen.zip(bindIdentGen, Gen.option(NTypeGen.genDepth03)), 4), recur, genA).map { case (as, e, t) =>
Lambda(as, e, t)
}),
(4, Gen.zip(bindIdentGen, recur, recur, Gen.oneOf(RecursionKind.Recursive, RecursionKind.NonRecursive), genA)
.map { case (a, e, in, r, t) => Let(a, e, in, r, t) }),
(1, Gen.zip(recur,
smallNonEmptyList(Gen.zip(genCompiledPattern(4), recur), 3),
genA).map { case (a, bs, t) => Match(a, bs, t)})
)
}
}
}
}

0 comments on commit 567a257

Please sign in to comment.