From b28b02398ba0a309a470fdc9740d54886d1ef138 Mon Sep 17 00:00:00 2001 From: Patrick Oscar Boykin Date: Tue, 21 Nov 2023 16:04:52 -1000 Subject: [PATCH] address some things noticed in review --- cli/src/main/protobuf/bosatsu/TypedAst.proto | 20 +++-- .../org/bykn/bosatsu/TypedExprToProto.scala | 57 +++++++-------- .../scala/org/bykn/bosatsu/TypeParser.scala | 4 +- .../scala/org/bykn/bosatsu/TypedExpr.scala | 19 ++--- .../scala/org/bykn/bosatsu/rankn/Infer.scala | 73 +++++++++---------- .../scala/org/bykn/bosatsu/rankn/Type.scala | 29 +++++--- .../scala/org/bykn/bosatsu/TestUtils.scala | 5 -- .../org/bykn/bosatsu/TypedExprTest.scala | 8 ++ .../bykn/bosatsu/rankn/RankNInferTest.scala | 2 + 9 files changed, 106 insertions(+), 111 deletions(-) diff --git a/cli/src/main/protobuf/bosatsu/TypedAst.proto b/cli/src/main/protobuf/bosatsu/TypedAst.proto index ecf80d99c..f29f8ddd5 100644 --- a/cli/src/main/protobuf/bosatsu/TypedAst.proto +++ b/cli/src/main/protobuf/bosatsu/TypedAst.proto @@ -35,16 +35,19 @@ message TypeVar { int32 varName = 1; } +message VarKind { + int32 varName = 1; + Kind kind = 2; +} + message TypeForAll { - repeated int32 varNames = 1; - repeated Kind kinds = 2; - int32 typeValue = 3; + repeated VarKind varKinds = 1; + int32 typeValue = 2; } message TypeExists { - repeated int32 varNames = 1; - repeated Kind kinds = 2; - int32 typeValue = 3; + repeated VarKind varKinds = 1; + int32 typeValue = 2; } // represents left[right] type application @@ -160,11 +163,6 @@ enum RecursionKind { IsRec = 1; } -message VarKind { - int32 varName = 1; - Kind kind = 2; -} - message GenericExpr { repeated VarKind forAlls = 1; repeated VarKind exists = 2; diff --git a/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala b/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala index 559c1e143..800ce9500 100644 --- a/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala +++ b/cli/src/main/scala/org/bykn/bosatsu/TypedExprToProto.scala @@ -239,6 +239,10 @@ object ProtoConverter { def str(i: Int): Try[String] = ds.tryString(i - 1, s"invalid string idx: $i in $p") + def varKindFromProto(vk: proto.VarKind) = + (str(vk.varName), kindFromProto(vk.kind)) + .mapN { (n, k) => (Type.Var.Bound(n), k) } + p.value match { case Value.Empty => Failure(new Exception(s"empty type found in $p")) case Value.TypeConst(tc) => @@ -251,25 +255,17 @@ object ProtoConverter { } case Value.TypeVar(tv) => str(tv.varName).map { n => Type.TyVar(Type.Var.Bound(n)) } - case Value.TypeForAll(TypeForAll(args, kinds, in, _)) => - if (args.length != kinds.length) - Failure(new Exception(s"args and kinds len mismatch: $p")) - else - for { - inT <- tpe(in) - args <- args.toList.traverse(str(_).map(Type.Var.Bound(_))) - kinds <- kinds.traverse { k => kindFromProto(Some(k)) } - } yield Type.forAll(args.zip(kinds), inT) - - case Value.TypeExists(TypeExists(args, kinds, in, _)) => - if (args.length != kinds.length) - Failure(new Exception(s"args and kinds len mismatch: $p")) - else - for { - inT <- tpe(in) - args <- args.toList.traverse(str(_).map(Type.Var.Bound(_))) - kinds <- kinds.traverse { k => kindFromProto(Some(k)) } - } yield Type.exists(args.zip(kinds), inT) + case Value.TypeForAll(TypeForAll(varKinds, in, _)) => + for { + inT <- tpe(in) + args <- varKinds.toList.traverse(varKindFromProto) + } yield Type.forAll(args, inT) + + case Value.TypeExists(TypeExists(varKinds, in, _)) => + for { + inT <- tpe(in) + args <- varKinds.toList.traverse(varKindFromProto) + } yield Type.exists(args, inT) case Value.TypeApply(TypeApply(left, right, _)) => (tpe(left), tpe(right)).mapN(Type.TyApply(_, _)) @@ -521,26 +517,24 @@ object ProtoConverter { val foralls = q.forallList val exs = q.existList val in = q.in - (foralls.traverse { case (b, _) => getId(b.name) }, - exs.traverse { case (b, _) => getId(b.name) }, + (foralls.traverse { case (b, k) => varKindToProto(b, k) }, + exs.traverse { case (b, k) => varKindToProto(b, k) }, typeToProto(in)) .flatMapN { (faids, exids, idx) => val ft0 = if (exs.nonEmpty) { - val eks = exs.map { case (_, k) => kindToProto(k) } val withEx = Type.exists(exs, in) getTypeId(withEx, proto.Type( - Value.TypeExists(TypeExists(exids, eks, idx)))) + Value.TypeExists(TypeExists(exids, idx)))) } else tabPure(idx) ft0.flatMap { t0 => if (foralls.nonEmpty) { - val fks = foralls.map { case (_, k) => kindToProto(k) } getTypeId(p, proto.Type( - Value.TypeForAll(TypeForAll(faids, fks, t0)))) + Value.TypeForAll(TypeForAll(faids, t0)))) } else tabPure(t0) } @@ -689,6 +683,11 @@ object ProtoConverter { } } + def varKindToProto(v: Type.Var.Bound, k: Kind): Tab[proto.VarKind] = + getId(v.name).map { id => + proto.VarKind(id, Some(kindToProto(k))) + } + def typedExprToProto(te: TypedExpr[Any]): Tab[Int] = StateT.get[Try, SerState] .map(_.expressions.indexOf(te)) @@ -699,14 +698,10 @@ object ProtoConverter { te match { case g@Generic(quant, expr) => val fas = quant.forallList.traverse { case (v, k) => - getId(v.name).map { id => - proto.VarKind(id, Some(kindToProto(k))) - } + varKindToProto(v, k) } val exs = quant.existList.traverse { case (v, k) => - getId(v.name).map { id => - proto.VarKind(id, Some(kindToProto(k))) - } + varKindToProto(v, k) } (fas, exs, typedExprToProto(expr)) .flatMapN { (fas, exs, exid) => diff --git a/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala b/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala index f00817a8b..6c4c3a3f9 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypeParser.scala @@ -36,12 +36,12 @@ abstract class TypeParser[A] { lowerIdent ~ kindP.? } - val univLike: P[(NonEmptyList[(String, Option[Kind])], A) => A] = + val quantified: P[(NonEmptyList[(String, Option[Kind])], A) => A] = keySpace("forall").as(universal(_, _)) | keySpace("exists").as(existential(_, _)) val lambda: P[MaybeTupleOrParens[A]] = - (univLike, univItem.nonEmptyListOfWs(maybeSpacesAndLines) ~ (maybeSpacesAndLines *> P.char('.') *> maybeSpacesAndLines *> recurse)) + (quantified, univItem.nonEmptyListOfWs(maybeSpacesAndLines) ~ (maybeSpacesAndLines *> P.char('.') *> maybeSpacesAndLines *> recurse)) .mapN { case (fn, (args, e)) => MaybeTupleOrParens.Bare(fn(args, e)) } val tupleOrParens: P[MaybeTupleOrParens[A]] = diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index e8910db1e..3f91586e0 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -266,7 +266,7 @@ object TypedExpr { case _ => None } - private[this] var emptyBound: SortedSet[Type.Var.Bound] = + private[this] val emptyBound: SortedSet[Type.Var.Bound] = SortedSet.empty implicit class InvariantTypedExpr[A](val self: TypedExpr[A]) extends AnyVal { @@ -458,9 +458,7 @@ object TypedExpr { .mapN { (typeArgs, r) => val forAlls = typeArgs.collect { case (nk, false) => nk } val exists = typeArgs.collect { case (nk, true) => nk } - val q = quantVars(forallList = forAlls, existList = exists, r) - //println(s"forAlls = $forAlls exists = $exists ${r.repr} quantVars = ${q.repr}") - q + quantVars(forallList = forAlls, existList = exists, r) } } @@ -470,7 +468,6 @@ object TypedExpr { for { envTypeVars <- getMetaTyVars(envList) localMetas = metas.diff(envTypeVars) - //_ = println(s"localMetas = $localMetas in ${te.repr}") q <- quantify0(localMetas.toList, te) } yield q } @@ -479,6 +476,7 @@ object TypedExpr { // this is lazy because we only evaluate it if there is an existential skolem lazy val envList = env.toList lazy val envExistSkols = Type.freeTyVars(envList) + .iterator .collect { case ex @ Skolem(_, _, true, _) => ex } @@ -510,12 +508,6 @@ object TypedExpr { getMetaTyVars(te1.allTypes.toList) .flatMap(quantifyMetas(envList, _, te1)) - /* - .map { res => - println(s"quantifyFree, teSkols=${teSkols} ${te.repr} => ${te1.repr} => ${res.repr}") - res - } - */ } /* @@ -530,7 +522,6 @@ object TypedExpr { def deepQuantify(env: Set[Type], te: TypedExpr[A]): F[TypedExpr[A]] = quantifyFree(env, te).flatMap { case Generic(quant, in) => - //assert(te != in, s"${te.repr} quantifyFree => ${in.repr}") deepQuantify(env + te.getType, in).map { in1 => quantVars(quant.forallList, quant.existList, in1) } @@ -603,10 +594,10 @@ object TypedExpr { deepQuantify(env1, arg).map(Match(_, branches, tag)) case Generic(quants, expr) => finish(expr).map(quantVars(quants.forallList, quants.existList, _)) + // $COVERAGE-OFF$ case unreach => - // $COVERAGE-OFF$ sys.error(s"Match quantification yielded neither Generic nor Match: $unreach") - // $COVERAGE-ON$ + // $COVERAGE-ON$ } noArg.flatMap(finish) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index aae40477d..e5da169e2 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -245,7 +245,10 @@ object Infer { private object Impl { sealed abstract class Expected[A] object Expected { - case class Inf[A](ref: Ref[Either[Error.InferIncomplete, A]]) extends Expected[A] + case class Inf[A](ref: Ref[Either[Error.InferIncomplete, A]]) extends Expected[A] { + def set(a: A): Infer[Unit] = + Infer.lift(ref.set(Right(a))) + } case class Check[A](value: A) extends Expected[A] } @@ -659,13 +662,6 @@ object Infer { } yield TypedExpr.coerceRho(t1, ck) // TODO this coerce seems right, since we have unified }) - def setInf(inf: Expected.Inf[(Type.Rho, Region)], rho: Type.Rho, r: Region): Infer[Unit] = - lift(inf.ref.update[Infer[Unit]] { - case Left(_) => (Right((rho, r)), unit) - case right@Right((rho1, r1)) => (right, unify(rho1, rho, r1, r)) - }) - .flatten - /* * Invariant: if the second argument is (Check rho) then rho is in weak prenex form */ @@ -677,7 +673,7 @@ object Infer { case infer@Expected.Inf(_) => for { (exSkols, rho) <- instantiate(sigma) - _ <- setInf(infer, rho, r) + _ <- infer.set((rho, r)) ks <- checkedKinds coerce = TypedExpr.coerceRho(rho, ks) } yield coerce.andThen(unskolemizeExists(exSkols)) @@ -1136,7 +1132,7 @@ object Infer { } typedBodyTpe <- extendEnvList(nameVarsT.toList)(inferRho(result)) (typedBody, bodyT) = typedBodyTpe - _ <- setInf(infer, Type.Fun(nameVarsT.map(_._2), bodyT), region(term)) + _ <- infer.set((Type.Fun(nameVarsT.map(_._2), bodyT), region(term))) } yield TypedExpr.AnnotatedLambda(nameVarsT, typedBody, tag) } case Let(name, rhs, body, isRecursive, tag) => @@ -1238,29 +1234,40 @@ object Infer { // are missing here. inferSigma(term) .flatMap { tsigma => - assertNoFree(tsigma.getType, s"line 1160 from $term\n\n${tsigma.repr}") - val check = Expected.Check((tsigma.getType, region(term))) expect match { case Expected.Check((resT, _)) => for { - env <- getEnv - unknownExs <- unsolvedExistentials(resT :: env.values.toList) - tbranches <- branches.parTraverse { case (p, r) => - // note, resT is in weak-prenex form, so this call is permitted - checkBranch(p, check, r, resT) - .product(solvedExistentitals(unknownExs).map((_, region(r)))) - } - _ <- unifyBranchExistentials(unknownExs, tbranches.map(_._2)) - } yield TypedExpr.Match(tsigma, tbranches.map(_._1), tag) + rest <- envTail + unknownExs <- unsolvedExistentials(resT :: rest) + tbranches <- + if (unknownExs.isEmpty) { + // in the common case there are no existentials save effort + branches.parTraverse { case (p, r) => + // note, resT is in weak-prenex form, so this call is permitted + checkBranch(p, check, r, resT) + } + } + else { + branches.parTraverse { case (p, r) => + // note, resT is in weak-prenex form, so this call is permitted + checkBranch(p, check, r, resT) + .product(solvedExistentitals(unknownExs).map((_, region(r)))) + } + .flatMap { tbranches => + unifyBranchExistentials(unknownExs, tbranches.map(_._2)) + .as(tbranches.map(_._1)) + } + } + } yield TypedExpr.Match(tsigma, tbranches, tag) case infer@Expected.Inf(_) => for { tbranches <- branches.parTraverse { case (p, r) => inferBranch(p, check, r) } (rho, regRho, resBranches) <- widenBranches(tbranches) - _ <- setInf(infer, rho, regRho) + _ <- infer.set((rho, regRho)) } yield TypedExpr.Match(tsigma, resBranches, tag) } } @@ -1352,7 +1359,6 @@ object Infer { * TODO: Pattern needs to have a region for each part */ def typeCheckPattern(pat: Pattern, sigma: Expected.Check[(Type, Region)], reg: Region): Infer[(Pattern, List[(Bindable, Type)])] = { - assertNoFree(sigma.value._1, "in typeCheckPattern line 1266") pat match { case GenPattern.WildCard => Infer.pure((pat, Nil)) case GenPattern.Literal(lit) => @@ -1415,7 +1421,6 @@ object Infer { Infer.pure((l, (splice, lst) :: Nil)) case ListPart.Item(p) => // This is a non-splice - assertNoFree(inner, s"line 1329 checkItem($inner, $lst, $e)") checkPat(p, inner, reg).map { case (p, l) => (ListPart.Item(p), l) } } val tpeOfList: Infer[Type] = @@ -1445,7 +1450,6 @@ object Infer { // like in the case of an annotation, we check the type, then // instantiate a sigma type // checkSigma(term, tpe) *> instSigma(tpe, expect) - assertNoFree(tpe, s"line 1359 $pat $reg") for { patBind <- checkPat(p, tpe, reg) (p1, binds) = patBind @@ -1460,7 +1464,6 @@ object Infer { // but we don't want to error type-checking since we want to show // the maximimum number of errors to the user envs <- args.zip(params).parTraverse { case (p, t) => - assertNoFree(t, s"pat: $pat line 1374") checkPat(p, t, reg) } pats = envs.map(_._1) @@ -1572,8 +1575,9 @@ object Infer { } pushDownCovariant(rest, nextRFA, left) match { case Type.ForAll(bs, l) => + // TODO: I think we can push down existentials too Type.forAll(bs, Type.TyApply(l, nextRight)) - case rho: Type.Rho => + case rho /*: Type.Rho */ => Type.TyApply(rho, nextRight) } case (_ :: rest, Type.TyApply(left, right)) => @@ -1583,8 +1587,9 @@ object Infer { Type.forAll(keptRight.reverse, pushDownCovariant(rest, lefts, left)) match { case Type.ForAll(bs, l) => + // TODO: we could possibly have an existential here? Type.forAll(bs, Type.TyApply(l, right)) - case rho: Type.Rho => + case rho /*: Type.Rho */=> Type.TyApply(rho, right) } case _ => @@ -1647,7 +1652,6 @@ object Infer { e <- env zrho <- zonkTypedExpr(rho) q <- TypedExpr.quantify(e, zrho, readMeta _, writeMeta _) - _ = assertNoFree(q.getType, s"line 1563 from quantify(${rho.repr}) => ${q.repr}") } yield q // allocate this once and reuse @@ -1734,13 +1738,6 @@ object Infer { case Left(err) => fail(err) } } yield (expr, tpe) - - def assertNoFree(t: Type, msg: => String): Unit = - Type.freeBoundTyVars(t :: Nil) match { - case Nil => () - case nel => - sys.error(s"expected no free vars in $t, found: $nel\n$msg") - } } private def recursiveTypeCheck[A: HasRegion](name: Bindable, expr: Expr[A]): Infer[TypedExpr[A]] = @@ -1837,9 +1834,7 @@ object Infer { .flatMap { groupChain => val glist = groupChain.toList extendEnvListPack(pack, glist.map { case (b, _, te) => - val t = te.getType - assertNoFree(t, s"line 1722 $b => ${te.repr}") - (b, t) + (b, te.getType) }) { run(tail) } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index fd7fba236..1fa6d30cb 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -140,7 +140,7 @@ object Type { new Order[Quantified] { def compare(a: Quantified, b: Quantified): Int = { val c = Order[Quantification].compare(a.quant, b.quant) - if (c == 0) Order[Type].compare(a.in, b.in) + if (c == 0) Order[Rho].compare(a.in, b.in) else c } } @@ -152,11 +152,23 @@ object Type { case class TyMeta(toMeta: Meta) extends Leaf def sameType(left: Type, right: Type): Boolean = - if (left.isInstanceOf[Leaf] && right.isInstanceOf[Leaf]) { - left == right - } - else { - normalize(left) == normalize(right) + left match { + case leftLeaf: Leaf => + // a Leaf is never equal to TyApply + right match { + case rightLeaf: Leaf => leftLeaf == rightLeaf + case _: TyApply => false + case q: Quantified => leftLeaf == normalize(q) + } + case _: TyApply => + if (right.isInstanceOf[Leaf]) false + else { + // left and right are not leafs + normalize(left) == normalize(right) + } + case _ => + // this is the quantified case + normalize(left) == normalize(right) } implicit val typeOrder: Order[Type] = @@ -350,7 +362,7 @@ object Type { val boundSet = q.vars.iterator.map(_._1).toSet[Type.Var] val env1 = env.iterator.filter { case (v, _) => !boundSet(v) }.toMap val subin = substituteVar(q.in, env1) - forAll(q.forallList, exists(q.existList, subin)) + quantify(q.quant, subin) }) def substituteRhoVar(t: Type.Rho, env: Map[Type.Var, Type.Rho]): Type.Rho = @@ -440,7 +452,7 @@ object Type { val forAllSize = fa1.size val normfas = newVars.take(forAllSize) val normexs = newVars.drop(forAllSize) - forAll(normfas, Type.exists(normexs, normin)) + quantify(forallList = normfas, existList = normexs, normin) case TyApply(on, arg) => TyApply(normalize(on), normalize(arg)) case _ => tpe } @@ -486,7 +498,6 @@ object Type { } case q: Quantified => val varList = q.vars.toList - require(varList.size == varList.toMap.size, s"invalid q: $q") rec((q.in, locals ++ varList)) } } diff --git a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala index ee859d9d9..bef60a3c1 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TestUtils.scala @@ -46,11 +46,6 @@ object TestUtils { case t@Type.TyVar(Type.Var.Skolem(_, _, _, _)) => sys.error(s"illegal skolem ($t) escape in ${te.repr}") case Type.TyVar(Type.Var.Bound(_)) => t - /* - this doesn't work correctly with traverseType - if (bound(b)) t - else sys.error(s"unbound var: $b in ${te.repr}") - */ case t@Type.TyMeta(_) => sys.error(s"illegal meta ($t) escape in ${te.repr}") case Type.TyApply(left, right) => diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index 8dcbc1f17..865e32620 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -882,4 +882,12 @@ x = ( assert(te.map(fn) == te.traverse[cats.Id, Int](fn)) } } + + test("freeTyVars is a superset of the frees in the outer type") { + forAll(genTypedExpr) { te => + assert(Type.freeTyVars(te.getType :: Nil).toSet.subsetOf( + te.freeTyVars.toSet + )) + } + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala index 5c8bcf970..15ba5e19c 100644 --- a/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/rankn/RankNInferTest.scala @@ -205,6 +205,8 @@ class RankNInferTest extends AnyFunSuite { // exists a. a is a top type assert_:<:("Int", "exists a. a") assert_:<:("(exists a. a) -> Int", "exists a. (a -> Int)") + assert_:<:("(exists a. a) -> Int", "Int -> Int") + assertTypesUnify("(exists a. a) -> Int", "forall a. a -> Int") assertTypesUnify("exists a. List[a]", "List[exists a. a]") assertTypesUnify("Int -> (exists a. a)", "exists a. (Int -> a)")