Skip to content

Commit

Permalink
add more Type tests
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Nov 21, 2023
1 parent 489d417 commit 284098f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 13 deletions.
16 changes: 6 additions & 10 deletions core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1235,11 +1235,11 @@ object TypedExpr {
def forAll[A](params: NonEmptyList[(Type.Var.Bound, Kind)], expr: TypedExpr[A]): TypedExpr[A] =
quantVars(forallList = params.toList, Nil, expr)

private def normalizeQuantVars[A](q: Type.Quantification, expr: TypedExpr[A]): TypedExpr[A] =
def normalizeQuantVars[A](q: Type.Quantification, expr: TypedExpr[A]): TypedExpr[A] =
expr match {
case Generic(oldQuant, ex0) =>
normalizeQuantVars(q.concat(oldQuant), ex0)
case Annotation(term, _) if Type.quantify(q, expr.getType).sameAs(term.getType) =>
case Annotation(term, tpe) if Type.quantify(q, tpe).sameAs(term.getType) =>
// we not uncommonly add an annotation just to make a generic wrapper to get back where
term
case _ =>
Expand All @@ -1257,10 +1257,8 @@ object TypedExpr {
case Some(q) =>
val varSet = q.vars.iterator.map { case (b, _) => b }.toSet

val avoid: SortedSet[Type.Var.Bound] =
expr.allTypes.collect {
case Type.TyVar(b: Type.Var.Bound) if !varSet(b) => b
}
val avoid: Set[Type.Var.Bound] =
expr.allBound.diff(varSet)

q match {
case ForAll(vars) =>
Expand Down Expand Up @@ -1304,13 +1302,11 @@ object TypedExpr {
def quantVars[A](
forallList: List[(Type.Var.Bound, Kind)],
existList: List[(Type.Var.Bound, Kind)],
expr: TypedExpr[A]): TypedExpr[A] = {

expr: TypedExpr[A]): TypedExpr[A] =
Type.Quantification.fromLists(forallList = forallList, existList = existList) match {
case Some(q) => normalizeQuantVars(q, expr)
case Some(q) => Generic(q, expr)
case None => expr
}
}

private def lambda[A](args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[A], tag: A): TypedExpr[A] =
AnnotatedLambda(args, expr, tag)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ object TypedExprNormalization {
normalize1(namerec, term, scope, typeEnv)
case Generic(quant, in) =>
val sin = normalize1(namerec, in, scope, typeEnv).get
val g1 = TypedExpr.quantVars(quant.forallList, quant.existList, sin)
val g1 = TypedExpr.normalizeQuantVars(quant, sin)
if (g1 == te) None
else Some(g1)
case Annotation(term, tpe) =>
Expand Down
27 changes: 25 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,23 @@ object NTypeGen {
} yield TyApply(TyApply(cons, param1), param2)))
}

val genQuantArgs: Gen[List[(Type.Var.Bound, Kind)]] =
for {
c <- Gen.choose(0, 5)
ks = NTypeGen.genKind
as <- Gen.listOfN(c, Gen.zip(genBound, ks))
} yield as

lazy val genQuant: Gen[Type.Quantification] =
Gen.zip(genQuantArgs, genQuantArgs)
.flatMap { case (fa, ex) =>
Type.Quantification.fromLists(fa, ex) match {
case Some(q) => Gen.const(q)
case None => genQuant
}
}


def genDepth(d: Int, genC: Option[Gen[Type.Const]]): Gen[Type] =
if (d <= 0) genRootType(genC)
else {
Expand All @@ -164,11 +181,17 @@ object NTypeGen {
in <- recurse
} yield Type.exists(as, in)

val genQ = Gen.zip(NTypeGen.genQuant, recurse).map { case (q, t) =>
Type.quantify(q, t)
}

val genApply = Gen.zip(recurse, recurse).map { case (a, b) => Type.TyApply(a, b) }

Gen.oneOf(recurse, genApply, genForAll, genExists)
Gen.frequency(
(2, recurse),
(1, genApply),
(1, Gen.oneOf(genForAll, genExists, genQ)))
}


val genDepth03: Gen[Type] = Gen.choose(0, 3).flatMap(genDepth(_, Some(genConst)))
}
46 changes: 46 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -245,4 +245,50 @@ class TypeTest extends AnyFunSuite {
}
}
}

test("Quantification.concat is associative") {
forAll(NTypeGen.genQuant, NTypeGen.genQuant, NTypeGen.genQuant) { (a, b, c) =>
assert(a.concat(b).concat(c) == a.concat(b.concat(c)))
}
}

test("Quantification.toLists/fromList identity") {
forAll(NTypeGen.genQuant) { q =>
assert(Type.Quantification.fromLists(q.forallList, q.existList) == Some(q))
}
}

test("unexists/exists | unforall/forall iso") {
forAll(NTypeGen.genDepth03) {
case t@Type.Exists(ps, in) =>
assert(Type.exists(ps, in) == t)
case t@Type.ForAll(ps, in) =>
assert(Type.forAll(ps, in) == t)
case _ => ()
}
}
test("exists -> unexists") {
forAll(NTypeGen.genQuantArgs, NTypeGen.genRootType(None)) { (args, t) =>
Type.exists(args, t) match {
case Type.Exists(ps, in) =>
assert(ps.toList == args)
assert(in == t)
case notExists =>
assert(args.isEmpty)
assert(notExists == t)
}
}
}
test("forall -> unforall") {
forAll(NTypeGen.genQuantArgs, NTypeGen.genRootType(None)) { (args, t) =>
Type.forAll(args, t) match {
case Type.ForAll(ps, in) =>
assert(ps.toList == args)
assert(in == t)
case notExists =>
assert(args.isEmpty)
assert(notExists == t)
}
}
}
}

0 comments on commit 284098f

Please sign in to comment.