Skip to content

Commit

Permalink
Hacking on TypedExprNormalization
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 20, 2023
1 parent cef4c90 commit 2248a6c
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 77 deletions.
6 changes: 6 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/ListUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ private[bosatsu] object ListUtil {
case Some(nel) => greedyGroup(nel)(one)(combine).toList
}

def mapConserveNel[A <: AnyRef](nel: NonEmptyList[A])(f: A => A): NonEmptyList[A] = {
val as = nel.toList
val bs = as.mapConserve(f)
if (bs eq as) nel
else NonEmptyList.fromListUnsafe(bs)
}
}
205 changes: 153 additions & 52 deletions core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ sealed abstract class TypedExpr[+T] { self: Product =>
*/
lazy val getType: Type =
this match {
case Generic(params, expr) =>
Type.forAll(params, expr.getType)
case g@Generic(_, _) => g.forAllType
case Annotation(_, tpe) =>
tpe
case AnnotatedLambda(args, res, _) =>
Expand All @@ -47,6 +46,19 @@ sealed abstract class TypedExpr[+T] { self: Product =>
branches.head._2.getType
}

lazy val size: Int =
this match {
case Generic(_, g) => g.size
case Annotation(a, _) => a.size
case AnnotatedLambda(_, res, _) =>
res.size
case Local(_, _, _) | Literal(_, _, _) | Global(_, _, _, _) => 1
case App(fn, args, _, _) => fn.size + args.foldMap(_.size)
case Let(_, e, in, _, _) => e.size + in.size
case Match(a, branches, _) =>
a.size + branches.foldMap(_._2.size)
}

// TODO: we need to make sure this parsable and maybe have a mode that has the compiler
// emit these
def repr: String = {
Expand Down Expand Up @@ -174,6 +186,8 @@ object TypedExpr {
*/
case class Generic[T](typeVars: NonEmptyList[(Type.Var.Bound, Kind)], in: TypedExpr[T]) extends TypedExpr[T] {
def tag: T = in.tag

lazy val forAllType: Type.ForAll = Type.forAll(typeVars, in.getType)
}
// Annotation really means "widen", the term has a type that is a subtype of coerce, so we are widening
// to the given type. This happens on Locals/Globals also in their tpe
Expand Down Expand Up @@ -606,46 +620,139 @@ object TypedExpr {

type Coerce = FunctionK[TypedExpr, TypedExpr]

private def pushDownCovariant(tpe: Type, kinds: Type => Option[Kind]): Type = {
import Type._
tpe match {
case ForAll(targs, in) =>
val (cons, cargs) = Type.unapplyAll(in)
kinds(cons) match {
case None =>
sys.error(s"unknown kind of $cons in $tpe")
case Some(kind) =>

val kindArgs = kind.toArgs
val kindArgsWithArgs = kindArgs.zip(cargs).map { case (ka, a) => (Some(ka), a) } :::
cargs.drop(kindArgs.length).map((None, _))

val argsVectorIdx = kindArgsWithArgs
.iterator
.zipWithIndex
.map { case ((optKA, tpe), idx) =>
(Type.freeBoundTyVars(tpe :: Nil).toSet, optKA, tpe, idx)
}
.toVector

// if an arg is covariant, it can pull all it's unique freeVars
def uniqueFreeVars(idx: Int): Set[Type.Var.Bound] = {
val (justIdx, optKA, _, _) = argsVectorIdx(idx)
if (optKA.exists(_.variance == Variance.co)) {
argsVectorIdx.iterator.filter(_._4 != idx)
.foldLeft(justIdx) { case (acc, (s, _, _, _)) => acc -- s }
}
else Set.empty
}
val withPulled = argsVectorIdx.map { case rec@(_, _, _, idx) =>
(rec, uniqueFreeVars(idx))
}
val allPulled: Set[Type.Var.Bound] = withPulled.foldMap(_._2)
val nonpulled = targs.filterNot { case (v, _) => allPulled(v) }
val pulledArgs = withPulled.iterator.map { case ((_, _, tpe, _), uniques) =>
val keep: Type.Var.Bound => Boolean = uniques
Type.forAll(targs.filter { case (t, _) => keep(t) }, tpe)
}
.toList
Type.forAll(nonpulled, Type.applyAll(cons, pulledArgs))
}

case _ => tpe
}
}

// We know initTpe <:< instTpe, we may be able to simply
// fix some of the universally quantified variables
private def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): Option[TypedExpr[A]] =
gen.getType match {
case Type.ForAll(bs, in) =>
import Type._
def solve(left: Type, right: Type, state: Map[Type.Var, Type], solveSet: Set[Type.Var]): Option[Map[Type.Var, Type]] =
(left, right) match {
case (TyVar(v), right) if solveSet(v) =>
Some(state.updated(v, right))
case (ForAll(b, i), r) =>
// this will mask solving for the inside values:
solve(i, r, state, solveSet -- b.toList.iterator.map(_._1))
case (_, ForAll(_, _)) =>
// TODO:
// if cons is covariant, all the free params of arg
// not in cons can pushed into arg
None
case (TyApply(on, arg), TyApply(on2, arg2)) =>
for {
s1 <- solve(on, on2, state, solveSet)
s2 <- solve(arg, arg2, s1, solveSet)
} yield s2
case (TyConst(_) | TyMeta(_) | TyVar(_), _) =>
if (left == right) {
// can't recurse further into left
Some(state)
}
else None
case (TyApply(_, _), _) => None
def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): TypedExpr[A] = {
import Type._

/*
def show(t: Type): String =
Type.fullyResolvedDocument.document(t).render(80)
*/

def solve(left: Type, right: Type, state: Map[Type.Var, Type], solveSet: Set[Type.Var]): Option[Map[Type.Var, Type]] =
(left, right) match {
case (TyVar(v), right) if solveSet(v) =>
Some(state.updated(v, right))
case (ForAll(b, i), r) =>
// this will mask solving for the inside values:
solve(i, r, state, solveSet -- b.toList.iterator.map(_._1))
case (_, fa@ForAll(_, _)) =>
val fa1 = pushDownCovariant(fa, kinds)
if (fa1 != fa) solve(left, fa1, state, solveSet)
else {
// not clear what to do here,
// the examples that come up look like un-unified
// types, as if coerceRho is called before we have
// finished unifying
//
//println(s"could not pushDown: ${show(fa)}")
None
}
case (TyApply(on, arg), TyApply(on2, arg2)) =>
for {
s1 <- solve(on, on2, state, solveSet)
s2 <- solve(arg, arg2, s1, solveSet)
} yield s2
case (TyConst(_) | TyMeta(_) | TyVar(_), _) =>
if (left == right) {
// can't recurse further into left
Some(state)
}
else None
case (TyApply(_, _), _) => None
}

val solveSet: Set[Var] = bs.toList.iterator.map(_._1).toSet
solve(in, instTpe, Map.empty, solveSet)
.flatMap { subs =>
if (subs.keySet == solveSet) Some(substituteTypeVar(gen.in, subs))
else None
val Type.ForAll(bs, in) = gen.forAllType
val solveSet: Set[Var] = bs.toList.iterator.map(_._1).toSet

val result =
solve(in, instTpe, Map.empty, solveSet)
.map { subs =>
val freeVars = solveSet -- subs.keySet
val subBody = substituteTypeVar(gen.in, subs)
val freeTypeVars = gen.typeVars.filter { case (t, _) => freeVars(t) }
NonEmptyList.fromList(freeTypeVars) match {
case None => subBody
case Some(frees) =>
val newGen = Generic(frees, subBody)
pushGeneric(newGen) match {
case badOpt @ (None | Some(Generic(_, _)))=>
// just wrap
//println(s"could not push frees instantiate: ${show(gen.getType)} to ${show(instTpe)}\n\n${badOpt.map(_.repr)}")
Annotation(badOpt.getOrElse(newGen), instTpe)
case Some(notGen) => notGen
}
}
case _ => None
}

result match {
case None =>
// TODO some of these just don't look fully unified yet, for instance:
// could not solve instantiate:
//
// forall b: *. Bosatsu/Predef::Order[b] -> forall a: *. Bosatsu/Predef::Dict[b, a]
//
// to
//
// Bosatsu/Predef::Order[?338] -> Bosatsu/Predef::Dict[$k$303, $v$304]
// but those two types aren't the same. It seems like we have to later
// learn that ?338 == $k$303, but we don't seem to know that yet

//println(s"could not solve instantiate: ${show(gen.getType)} to ${show(instTpe)}")
// just add an annotation:
Annotation(gen, instTpe)
case Some(res) => res
}
}

private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] =
p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1
Expand All @@ -654,13 +761,15 @@ object TypedExpr {
g.in match {
case AnnotatedLambda(args, body, a) =>
val argFree = Type.freeBoundTyVars(args.toList.map(_._2)).toSet
if (g.typeVars.exists { case (b, _) => argFree(b) }) {
None
}
else {
val gbody = Generic(g.typeVars, body)
val (outer, inner) = g.typeVars.toList.partition { case (b, _) => argFree(b) }
NonEmptyList.fromList(inner).map { inner =>
val gbody = Generic(inner, body)
val pushedBody = pushGeneric(gbody).getOrElse(gbody)
Some(AnnotatedLambda(args, pushedBody, a))
val lam = AnnotatedLambda(args, gbody, a)
NonEmptyList.fromList(outer) match {
case None => lam
case Some(outer) => forAll(outer, lam)
}
}
// we can do the same thing on Match
case Match(arg, branches, tag) =>
Expand Down Expand Up @@ -719,12 +828,7 @@ object TypedExpr {
pushGeneric(gen) match {
case Some(e1) => self(e1)
case None =>
instantiateTo(gen, tpe, kinds) match {
case Some(res) => res
case None =>
// TODO: this is basically giving up
Annotation(gen, tpe)
}
instantiateTo(gen, tpe, kinds)
}
case App(fn, aargs, _, tag) =>
fn match {
Expand Down Expand Up @@ -959,10 +1063,7 @@ object TypedExpr {
pushGeneric(gen) match {
case Some(e1) => self(e1)
case None =>
instantiateTo(gen, fntpe, kinds) match {
case Some(res) => res
case None => Annotation(gen, fntpe)
}
instantiateTo(gen, fntpe, kinds)
}
case Local(_, _, _) | Global(_, _, _, _) | Literal(_, _, _) =>
Annotation(expr, fntpe)
Expand Down
Loading

0 comments on commit 2248a6c

Please sign in to comment.