diff --git a/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala b/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala index 22af4a10f..018d3475a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala +++ b/core/src/main/scala/org/bykn/bosatsu/ListUtil.scala @@ -38,4 +38,10 @@ private[bosatsu] object ListUtil { case Some(nel) => greedyGroup(nel)(one)(combine).toList } + def mapConserveNel[A <: AnyRef](nel: NonEmptyList[A])(f: A => A): NonEmptyList[A] = { + val as = nel.toList + val bs = as.mapConserve(f) + if (bs eq as) nel + else NonEmptyList.fromListUnsafe(bs) + } } \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index bdb33df7d..8b707f17e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -29,8 +29,7 @@ sealed abstract class TypedExpr[+T] { self: Product => */ lazy val getType: Type = this match { - case Generic(params, expr) => - Type.forAll(params, expr.getType) + case g@Generic(_, _) => g.forAllType case Annotation(_, tpe) => tpe case AnnotatedLambda(args, res, _) => @@ -47,6 +46,19 @@ sealed abstract class TypedExpr[+T] { self: Product => branches.head._2.getType } + lazy val size: Int = + this match { + case Generic(_, g) => g.size + case Annotation(a, _) => a.size + case AnnotatedLambda(_, res, _) => + res.size + case Local(_, _, _) | Literal(_, _, _) | Global(_, _, _, _) => 1 + case App(fn, args, _, _) => fn.size + args.foldMap(_.size) + case Let(_, e, in, _, _) => e.size + in.size + case Match(a, branches, _) => + a.size + branches.foldMap(_._2.size) + } + // TODO: we need to make sure this parsable and maybe have a mode that has the compiler // emit these def repr: String = { @@ -174,6 +186,8 @@ object TypedExpr { */ case class Generic[T](typeVars: NonEmptyList[(Type.Var.Bound, Kind)], in: TypedExpr[T]) extends TypedExpr[T] { def tag: T = in.tag + + lazy val forAllType: Type.ForAll = Type.forAll(typeVars, in.getType) } // Annotation really means "widen", the term has a type that is a subtype of coerce, so we are widening // to the given type. This happens on Locals/Globals also in their tpe @@ -606,46 +620,139 @@ object TypedExpr { type Coerce = FunctionK[TypedExpr, TypedExpr] + private def pushDownCovariant(tpe: Type, kinds: Type => Option[Kind]): Type = { + import Type._ + tpe match { + case ForAll(targs, in) => + val (cons, cargs) = Type.unapplyAll(in) + kinds(cons) match { + case None => + sys.error(s"unknown kind of $cons in $tpe") + case Some(kind) => + + val kindArgs = kind.toArgs + val kindArgsWithArgs = kindArgs.zip(cargs).map { case (ka, a) => (Some(ka), a) } ::: + cargs.drop(kindArgs.length).map((None, _)) + + val argsVectorIdx = kindArgsWithArgs + .iterator + .zipWithIndex + .map { case ((optKA, tpe), idx) => + (Type.freeBoundTyVars(tpe :: Nil).toSet, optKA, tpe, idx) + } + .toVector + + // if an arg is covariant, it can pull all it's unique freeVars + def uniqueFreeVars(idx: Int): Set[Type.Var.Bound] = { + val (justIdx, optKA, _, _) = argsVectorIdx(idx) + if (optKA.exists(_.variance == Variance.co)) { + argsVectorIdx.iterator.filter(_._4 != idx) + .foldLeft(justIdx) { case (acc, (s, _, _, _)) => acc -- s } + } + else Set.empty + } + val withPulled = argsVectorIdx.map { case rec@(_, _, _, idx) => + (rec, uniqueFreeVars(idx)) + } + val allPulled: Set[Type.Var.Bound] = withPulled.foldMap(_._2) + val nonpulled = targs.filterNot { case (v, _) => allPulled(v) } + val pulledArgs = withPulled.iterator.map { case ((_, _, tpe, _), uniques) => + val keep: Type.Var.Bound => Boolean = uniques + Type.forAll(targs.filter { case (t, _) => keep(t) }, tpe) + } + .toList + Type.forAll(nonpulled, Type.applyAll(cons, pulledArgs)) + } + + case _ => tpe + } + } + // We know initTpe <:< instTpe, we may be able to simply // fix some of the universally quantified variables - private def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): Option[TypedExpr[A]] = - gen.getType match { - case Type.ForAll(bs, in) => - import Type._ - def solve(left: Type, right: Type, state: Map[Type.Var, Type], solveSet: Set[Type.Var]): Option[Map[Type.Var, Type]] = - (left, right) match { - case (TyVar(v), right) if solveSet(v) => - Some(state.updated(v, right)) - case (ForAll(b, i), r) => - // this will mask solving for the inside values: - solve(i, r, state, solveSet -- b.toList.iterator.map(_._1)) - case (_, ForAll(_, _)) => - // TODO: - // if cons is covariant, all the free params of arg - // not in cons can pushed into arg - None - case (TyApply(on, arg), TyApply(on2, arg2)) => - for { - s1 <- solve(on, on2, state, solveSet) - s2 <- solve(arg, arg2, s1, solveSet) - } yield s2 - case (TyConst(_) | TyMeta(_) | TyVar(_), _) => - if (left == right) { - // can't recurse further into left - Some(state) - } - else None - case (TyApply(_, _), _) => None + def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): TypedExpr[A] = { + import Type._ + + /* + def show(t: Type): String = + Type.fullyResolvedDocument.document(t).render(80) + */ + + def solve(left: Type, right: Type, state: Map[Type.Var, Type], solveSet: Set[Type.Var]): Option[Map[Type.Var, Type]] = + (left, right) match { + case (TyVar(v), right) if solveSet(v) => + Some(state.updated(v, right)) + case (ForAll(b, i), r) => + // this will mask solving for the inside values: + solve(i, r, state, solveSet -- b.toList.iterator.map(_._1)) + case (_, fa@ForAll(_, _)) => + val fa1 = pushDownCovariant(fa, kinds) + if (fa1 != fa) solve(left, fa1, state, solveSet) + else { + // not clear what to do here, + // the examples that come up look like un-unified + // types, as if coerceRho is called before we have + // finished unifying + // + //println(s"could not pushDown: ${show(fa)}") + None + } + case (TyApply(on, arg), TyApply(on2, arg2)) => + for { + s1 <- solve(on, on2, state, solveSet) + s2 <- solve(arg, arg2, s1, solveSet) + } yield s2 + case (TyConst(_) | TyMeta(_) | TyVar(_), _) => + if (left == right) { + // can't recurse further into left + Some(state) } + else None + case (TyApply(_, _), _) => None + } - val solveSet: Set[Var] = bs.toList.iterator.map(_._1).toSet - solve(in, instTpe, Map.empty, solveSet) - .flatMap { subs => - if (subs.keySet == solveSet) Some(substituteTypeVar(gen.in, subs)) - else None + val Type.ForAll(bs, in) = gen.forAllType + val solveSet: Set[Var] = bs.toList.iterator.map(_._1).toSet + + val result = + solve(in, instTpe, Map.empty, solveSet) + .map { subs => + val freeVars = solveSet -- subs.keySet + val subBody = substituteTypeVar(gen.in, subs) + val freeTypeVars = gen.typeVars.filter { case (t, _) => freeVars(t) } + NonEmptyList.fromList(freeTypeVars) match { + case None => subBody + case Some(frees) => + val newGen = Generic(frees, subBody) + pushGeneric(newGen) match { + case badOpt @ (None | Some(Generic(_, _)))=> + // just wrap + //println(s"could not push frees instantiate: ${show(gen.getType)} to ${show(instTpe)}\n\n${badOpt.map(_.repr)}") + Annotation(badOpt.getOrElse(newGen), instTpe) + case Some(notGen) => notGen + } } - case _ => None + } + + result match { + case None => + // TODO some of these just don't look fully unified yet, for instance: + // could not solve instantiate: + // + // forall b: *. Bosatsu/Predef::Order[b] -> forall a: *. Bosatsu/Predef::Dict[b, a] + // + // to + // + // Bosatsu/Predef::Order[?338] -> Bosatsu/Predef::Dict[$k$303, $v$304] + // but those two types aren't the same. It seems like we have to later + // learn that ?338 == $k$303, but we don't seem to know that yet + + //println(s"could not solve instantiate: ${show(gen.getType)} to ${show(instTpe)}") + // just add an annotation: + Annotation(gen, instTpe) + case Some(res) => res } + } private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] = p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1 @@ -654,13 +761,15 @@ object TypedExpr { g.in match { case AnnotatedLambda(args, body, a) => val argFree = Type.freeBoundTyVars(args.toList.map(_._2)).toSet - if (g.typeVars.exists { case (b, _) => argFree(b) }) { - None - } - else { - val gbody = Generic(g.typeVars, body) + val (outer, inner) = g.typeVars.toList.partition { case (b, _) => argFree(b) } + NonEmptyList.fromList(inner).map { inner => + val gbody = Generic(inner, body) val pushedBody = pushGeneric(gbody).getOrElse(gbody) - Some(AnnotatedLambda(args, pushedBody, a)) + val lam = AnnotatedLambda(args, gbody, a) + NonEmptyList.fromList(outer) match { + case None => lam + case Some(outer) => forAll(outer, lam) + } } // we can do the same thing on Match case Match(arg, branches, tag) => @@ -719,12 +828,7 @@ object TypedExpr { pushGeneric(gen) match { case Some(e1) => self(e1) case None => - instantiateTo(gen, tpe, kinds) match { - case Some(res) => res - case None => - // TODO: this is basically giving up - Annotation(gen, tpe) - } + instantiateTo(gen, tpe, kinds) } case App(fn, aargs, _, tag) => fn match { @@ -959,10 +1063,7 @@ object TypedExpr { pushGeneric(gen) match { case Some(e1) => self(e1) case None => - instantiateTo(gen, fntpe, kinds) match { - case Some(res) => res - case None => Annotation(gen, fntpe) - } + instantiateTo(gen, fntpe, kinds) } case Local(_, _, _) | Global(_, _, _, _) | Literal(_, _, _) => Annotation(expr, fntpe) diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index de199cd3a..1c3def394 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -36,7 +36,7 @@ object TypedExprNormalization { if (r.isRecursive) (Some(b), scope - b) else (None, scope) - def normalizeAll[A, V](pack: PackageName, lets: List[(Bindable, RecursionKind, TypedExpr[A])], typeEnv: TypeEnv[V]): List[(Bindable, RecursionKind, TypedExpr[A])] = { + def normalizeAll[A, V](pack: PackageName, lets: List[(Bindable, RecursionKind, TypedExpr[A])], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): List[(Bindable, RecursionKind, TypedExpr[A])] = { @annotation.tailrec def loop(scope: Scope[A], lets: List[(Bindable, RecursionKind, TypedExpr[A])], acc: List[(Bindable, RecursionKind, TypedExpr[A])]): List[(Bindable, RecursionKind, TypedExpr[A])] = lets match { @@ -55,7 +55,7 @@ object TypedExprNormalization { def normalizeProgram[A, V]( p: PackageName, fullTypeEnv: TypeEnv[V], - prog: Program[TypeEnv[V], TypedExpr[Declaration], A]): Program[TypeEnv[V], TypedExpr[Declaration], A] = { + prog: Program[TypeEnv[V], TypedExpr[Declaration], A])(implicit ev: V <:< Kind.Arg): Program[TypeEnv[V], TypedExpr[Declaration], A] = { val Program(typeEnv, lets, extDefs, stmts) = prog val normalLets = normalizeAll(p, lets, fullTypeEnv) Program(typeEnv, normalLets, extDefs, stmts) @@ -63,7 +63,7 @@ object TypedExprNormalization { // if you have made one step of progress, use this to recurse // so we don't throw away if we don't progress more - private def normalize1[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V]): Some[TypedExpr[A]] = + private def normalize1[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): Some[TypedExpr[A]] = normalizeLetOpt(namerec, te, scope, typeEnv) match { case None => Some(te) case s@Some(_) => s @@ -75,7 +75,13 @@ object TypedExprNormalization { /** * if the te is not in normal form, transform it into normal form */ - def normalizeLetOpt[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V]): Option[TypedExpr[A]] = + def normalizeLetOpt[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): Option[TypedExpr[A]] = { + val kindOf: Type => Option[Kind] = + { case const @ Type.TyConst(_) => + typeEnv.getType(const).map(_.kindOf) + case _ => None + } + te match { case g@Generic(_, Annotation(term, _)) if g.getType.sameAs(term.getType) => normalize1(namerec, term, scope, typeEnv) @@ -112,11 +118,15 @@ object TypedExprNormalization { // if we annotate twice, we can ignore the inner annotation // we should have type annotation where we normalize type parameters val e1 = normalize1(namerec, term, scope, typeEnv).get - e1 match { + (e1, tpe) match { case _ if e1.getType.sameAs(tpe) => // the type is already right Some(e1) - case notSameTpe => + case (gen@Generic(_, _), rho: Type.Rho) => + val inst = TypedExpr.instantiateTo(gen, rho, kindOf) + if (inst != te) Some(inst) + else None + case (notSameTpe, _) => val nt = Type.normalize(tpe) if (notSameTpe eq term) { if (nt == tpe) None @@ -208,10 +218,13 @@ object TypedExprNormalization { val f1 = normalize1(None, fn, scope, typeEnv).get // the second and third branches use this but the first doesn't // make it lazy so we don't recurse more than needed - lazy val a1 = args.map(normalize1(None, _, scope, typeEnv).get) - val ws = Impl.WithScope(scope) + lazy val a1 = ListUtil.mapConserveNel(args) { a => + normalize1(None, a, scope, typeEnv).get + } + + val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv)) f1 match { - case ws.ResolveToLambda(lamArgs, expr, _) => + case ws.ResolveToLambda(Nil, lamArgs, expr, _) => // (y -> z)(x) = let y = x in z val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => (n, setType(arg, ltpe)) @@ -223,7 +236,7 @@ object TypedExprNormalization { // (app (let x y z) w) == (let x y (app z w)) if w does not have x free normalize1(namerec, Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), scope, typeEnv) case _ => - if ((f1 eq fn) && (a1 == args) && (tpe == tpe0)) None + if ((f1 eq fn) && (a1 eq args) && (tpe == tpe0)) None else Some(App(f1, a1, tpe, tag)) } case Let(arg, ex, in, rec, tag) => @@ -345,9 +358,10 @@ object TypedExprNormalization { // does not depend on the arg if (a1 eq arg) None else Some(m1) - case Some(m2) => + case Some(m2) if m2.size < m1.size => // we can possibly simplify this now: normalize1(namerec, m2, scope, typeEnv) + case _ => None } } else { @@ -356,6 +370,7 @@ object TypedExprNormalization { normalize1(namerec, Match(a1, branches1a, tag), scope, typeEnv) } } + } def normalize[A](te: TypedExpr[A]): Option[TypedExpr[A]] = normalizeLetOpt(None, te, emptyScope, TypeEnv.empty) @@ -374,22 +389,39 @@ object TypedExprNormalization { } } - case class WithScope[A](scope: Scope[A]) { + case class WithScope[A](scope: Scope[A], typeEnv: TypeEnv[Kind.Arg]) { + private lazy val kindOf: Type => Option[Kind] = + { case const @ Type.TyConst(_) => + typeEnv.getType(const).map(_.kindOf) + case _ => None + } + object ResolveToLambda { - def unapply(te: TypedExpr[A]): Option[(NonEmptyList[(Bindable, Type)], TypedExpr[A], A)] = + def unapply(te: TypedExpr[A]): Option[(List[(Type.Var.Bound, Kind)], NonEmptyList[(Bindable, Type)], TypedExpr[A], A)] = te match { - case AnnotatedLambda(args, expr, ltag) => Some((args, expr, ltag)) + case Annotation(ResolveToLambda((h :: t), args, ex, tag), rho: Type.Rho) => + val asGen = + Generic(NonEmptyList(h, t), AnnotatedLambda(args, ex, tag)) + + TypedExpr.instantiateTo(asGen, rho, kindOf) match { + case AnnotatedLambda(a, e, t) => Some((Nil, a, e, t)) + case Generic(nel, AnnotatedLambda(a, e, t)) => Some((nel.toList, a, e, t)) + case _ => None + } + case Generic(frees, ResolveToLambda(f1, args, ex, tag)) => + Some((frees.toList ::: f1, args, ex, tag)) + case AnnotatedLambda(args, expr, ltag) => Some((Nil, args, expr, ltag)) case Global(p, n: Bindable, _, _) => scope.getGlobal(p, n).flatMap { case (RecursionKind.NonRecursive, te, scope1) => - val s1 = WithScope(scope1) + val s1 = WithScope(scope1, typeEnv) te match { - case s1.ResolveToLambda(args, expr, ltag) => + case s1.ResolveToLambda(frees, args, expr, ltag) => // we can't just replace variables if the scopes don't match. // we could also repair the scope by making a let binding // for any names that don't match (which has to be done recursively if (scopeMatches(expr.freeVarsDup.toSet -- args.iterator.map(_._1), scope, scope1)) { - Some((args, expr, ltag)) + Some((frees, args, expr, ltag)) } else None case _ => None @@ -399,14 +431,14 @@ object TypedExprNormalization { case Local(nm, _, _) => scope.getLocal(nm).flatMap { case (RecursionKind.NonRecursive, te, scope1) => - val s1 = WithScope(scope1) + val s1 = WithScope(scope1, typeEnv) te match { - case s1.ResolveToLambda(args, expr, ltag) => + case s1.ResolveToLambda(frees, args, expr, ltag) => // we can't just replace variables if the scopes don't match. // we could also repair the scope by making a let binding // for any names that don't match (which has to be done recursively if (scopeMatches(expr.freeVarsDup.toSet -- args.iterator.map(_._1), scope, scope1)) { - Some((args, expr, ltag)) + Some((frees, args, expr, ltag)) } else None case _ => None @@ -541,9 +573,11 @@ object TypedExprNormalization { m.branches.toList.traverse(expandMatches).map(_.flatten).flatMap { case Nil => + // TODO hitting this looks like a bug // $COVERAGE-OFF$ - sys.error(s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})") + //sys.error(s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})") // $COVERAGE-ON$ + None case (MaybeNamedStruct(b, pats), r) :: rest if rest.isEmpty || pats.forall(isTotal) => // If there are no more items, or all inner patterns are total, we are done diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala index 99293c750..ce5ebbfd6 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Infer.scala @@ -442,7 +442,7 @@ object Infer { readMeta(m).flatMap { case None => pure(None) case Some(ty) => - Type.zonkRhoMeta(ty)(zonk(_)).flatMap { ty1 => + zonkRho(ty).flatMap { ty1 => // short out multiple hops (I guess an optimization?) // note: this meta was already written, so we know // the kind must match @@ -450,6 +450,9 @@ object Infer { } } + def zonkRho(rho: Type.Rho): Infer[Type.Rho] = + Type.zonkRhoMeta(rho)(zonk(_)) + /** * This fills in any meta vars that have been * quantified and replaces them with what they point to @@ -605,7 +608,10 @@ object Infer { } yield TypedExpr.coerceRho(ta, ks) case (t1, t2) => // rule: MONO - unify(t1, t2, left, right) *> checkedKinds.map(TypedExpr.coerceRho(t1, _)) // TODO this coerce seems right, since we have unified + for { + _ <- unify(t1, t2, left, right) + ck <- checkedKinds + } yield TypedExpr.coerceRho(t1, ck) // TODO this coerce seems right, since we have unified }) /* @@ -621,6 +627,7 @@ object Infer { rho <- instantiate(sigma, r) _ <- infer.set((rho, r)) ks <- checkedKinds + // there is no point in zonking here, we just instantiated rho } yield TypedExpr.coerceRho(rho, ks) } diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala index 01f078891..ee1a7af12 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/Type.scala @@ -113,7 +113,7 @@ object Type { case Some(ne) => forAll(ne, in) } - final def forAll(vars: NonEmptyList[(Var.Bound, Kind)], in: Type): Type = + final def forAll(vars: NonEmptyList[(Var.Bound, Kind)], in: Type): Type.ForAll = in match { case rho: Rho => Type.ForAll(vars, rho) case Type.ForAll(ne1, rho) => Type.ForAll(vars ::: ne1, rho) diff --git a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala index 5239aa6f2..d61586d2a 100644 --- a/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala +++ b/core/src/main/scala/org/bykn/bosatsu/rankn/TypeEnv.scala @@ -44,6 +44,11 @@ class TypeEnv[+A] private ( def getType(p: PackageName, t: TypeName): Option[DefinedType[A]] = definedTypes.get((p, t)) + def getType(t: Type.TyConst): Option[DefinedType[A]] = { + val d = t.tpe.toDefined + getType(d.packageName, d.name) + } + def getExternalValue(p: PackageName, n: Identifier): Option[Type] = values.get((p, n)) diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index 164b0bade..d4a387f07 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -522,7 +522,8 @@ res = y -> (x -> f(x, z))(y) checkLast(""" f = (_, y) -> y -res = y -> f(y, 1) +#res = y -> f(y, 1) +res = _ -> 1 """) { te2 => assert(te1.void == te2.void, s"${te1.repr} != ${te2.repr}") }