diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index 15c338bf3..2707e8fe7 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -1235,11 +1235,11 @@ object TypedExpr { def forAll[A](params: NonEmptyList[(Type.Var.Bound, Kind)], expr: TypedExpr[A]): TypedExpr[A] = quantVars(forallList = params.toList, Nil, expr) - private def normalizeQuantVars[A](q: Type.Quantification, expr: TypedExpr[A]): TypedExpr[A] = + def normalizeQuantVars[A](q: Type.Quantification, expr: TypedExpr[A]): TypedExpr[A] = expr match { case Generic(oldQuant, ex0) => normalizeQuantVars(q.concat(oldQuant), ex0) - case Annotation(term, _) if Type.quantify(q, expr.getType).sameAs(term.getType) => + case Annotation(term, tpe) if Type.quantify(q, tpe).sameAs(term.getType) => // we not uncommonly add an annotation just to make a generic wrapper to get back where term case _ => @@ -1257,10 +1257,8 @@ object TypedExpr { case Some(q) => val varSet = q.vars.iterator.map { case (b, _) => b }.toSet - val avoid: SortedSet[Type.Var.Bound] = - expr.allTypes.collect { - case Type.TyVar(b: Type.Var.Bound) if !varSet(b) => b - } + val avoid: Set[Type.Var.Bound] = + expr.allBound.diff(varSet) q match { case ForAll(vars) => @@ -1304,13 +1302,11 @@ object TypedExpr { def quantVars[A]( forallList: List[(Type.Var.Bound, Kind)], existList: List[(Type.Var.Bound, Kind)], - expr: TypedExpr[A]): TypedExpr[A] = { - + expr: TypedExpr[A]): TypedExpr[A] = Type.Quantification.fromLists(forallList = forallList, existList = existList) match { - case Some(q) => normalizeQuantVars(q, expr) + case Some(q) => Generic(q, expr) case None => expr } - } private def lambda[A](args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[A], tag: A): TypedExpr[A] = AnnotatedLambda(args, expr, tag) diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index 04c380718..4c2dd898a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -92,7 +92,7 @@ object TypedExprNormalization { normalize1(namerec, term, scope, typeEnv) case Generic(quant, in) => val sin = normalize1(namerec, in, scope, typeEnv).get - val g1 = TypedExpr.quantVars(quant.forallList, quant.existList, sin) + val g1 = TypedExpr.normalizeQuantVars(quant, sin) if (g1 == te) None else Some(g1) case Annotation(term, tpe) => diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala index b63a7523a..72d18595f 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/NTypeGen.scala @@ -144,6 +144,23 @@ object NTypeGen { } yield TyApply(TyApply(cons, param1), param2))) } + val genQuantArgs: Gen[List[(Type.Var.Bound, Kind)]] = + for { + c <- Gen.choose(0, 5) + ks = NTypeGen.genKind + as <- Gen.listOfN(c, Gen.zip(genBound, ks)) + } yield as + + lazy val genQuant: Gen[Type.Quantification] = + Gen.zip(genQuantArgs, genQuantArgs) + .flatMap { case (fa, ex) => + Type.Quantification.fromLists(fa, ex) match { + case Some(q) => Gen.const(q) + case None => genQuant + } + } + + def genDepth(d: Int, genC: Option[Gen[Type.Const]]): Gen[Type] = if (d <= 0) genRootType(genC) else { @@ -164,11 +181,17 @@ object NTypeGen { in <- recurse } yield Type.exists(as, in) + val genQ = Gen.zip(NTypeGen.genQuant, recurse).map { case (q, t) => + Type.quantify(q, t) + } + val genApply = Gen.zip(recurse, recurse).map { case (a, b) => Type.TyApply(a, b) } - Gen.oneOf(recurse, genApply, genForAll, genExists) + Gen.frequency( + (2, recurse), + (1, genApply), + (1, Gen.oneOf(genForAll, genExists, genQ))) } - val genDepth03: Gen[Type] = Gen.choose(0, 3).flatMap(genDepth(_, Some(genConst))) } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala index bd98e2bd5..39a7ae926 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/TypeTest.scala @@ -245,4 +245,50 @@ class TypeTest extends AnyFunSuite { } } } + + test("Quantification.concat is associative") { + forAll(NTypeGen.genQuant, NTypeGen.genQuant, NTypeGen.genQuant) { (a, b, c) => + assert(a.concat(b).concat(c) == a.concat(b.concat(c))) + } + } + + test("Quantification.toLists/fromList identity") { + forAll(NTypeGen.genQuant) { q => + assert(Type.Quantification.fromLists(q.forallList, q.existList) == Some(q)) + } + } + + test("unexists/exists | unforall/forall iso") { + forAll(NTypeGen.genDepth03) { + case t@Type.Exists(ps, in) => + assert(Type.exists(ps, in) == t) + case t@Type.ForAll(ps, in) => + assert(Type.forAll(ps, in) == t) + case _ => () + } + } + test("exists -> unexists") { + forAll(NTypeGen.genQuantArgs, NTypeGen.genRootType(None)) { (args, t) => + Type.exists(args, t) match { + case Type.Exists(ps, in) => + assert(ps.toList == args) + assert(in == t) + case notExists => + assert(args.isEmpty) + assert(notExists == t) + } + } + } + test("forall -> unforall") { + forAll(NTypeGen.genQuantArgs, NTypeGen.genRootType(None)) { (args, t) => + Type.forAll(args, t) match { + case Type.ForAll(ps, in) => + assert(ps.toList == args) + assert(in == t) + case notExists => + assert(args.isEmpty) + assert(notExists == t) + } + } + } }