Skip to content

Commit

Permalink
Hacking on TypedExprNormalization (#1050)
Browse files Browse the repository at this point in the history
* Hacking on TypedExprNormalization

* checkpoint, but List sort is broken by the changes

* Fix change tests to pass

* implement kindOf for TypedExprNormalization

* fix kindOf in WithScope

* fix scope bug in TypedExprNormalization

* Add specific regression test

* fix issue with SelfCallKind

* unify kindOf code

* re-enable test, cleanups

* some minor improvements

* try to improve test coverage
  • Loading branch information
johnynek authored Sep 30, 2023
1 parent 7b7f9fa commit 0fbec80
Show file tree
Hide file tree
Showing 10 changed files with 567 additions and 215 deletions.
14 changes: 9 additions & 5 deletions core/src/main/scala/org/bykn/bosatsu/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ object Expr {
}
}

private[bosatsu] def nameIterator(): Iterator[Bindable] =
Type
.allBinders
.iterator
.map(_.name)
.map(Identifier.Name(_))


def buildPatternLambda[A](
args: NonEmptyList[Pattern[(PackageName, Constructor), Type]],
body: Expr[A],
Expand All @@ -340,11 +348,7 @@ object Expr {
* compute this once if needed, which is why it is lazy.
* we don't want to traverse body if it is never needed
*/
lazy val anons = Type
.allBinders
.iterator
.map(_.name)
.map(Identifier.Name(_))
lazy val anons = nameIterator()
.filterNot(allNames(body) ++ args.patternNames)

type P = Pattern[(PackageName, Constructor), Type]
Expand Down
6 changes: 6 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/ListUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,10 @@ private[bosatsu] object ListUtil {
case Some(nel) => greedyGroup(nel)(one)(combine).toList
}

def mapConserveNel[A <: AnyRef, B >: A <: AnyRef](nel: NonEmptyList[A])(f: A => B): NonEmptyList[B] = {
val as = nel.toList
val bs = as.mapConserve(f)
if (bs eq as) nel
else NonEmptyList.fromListUnsafe(bs)
}
}
14 changes: 9 additions & 5 deletions core/src/main/scala/org/bykn/bosatsu/SelfCallKind.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ object SelfCallKind {
val ifNoCallSemigroup: Semigroup[SelfCallKind] =
Semigroup.instance(_.ifNoCallThen(_))

private def isFn[A](n: Bindable, te: TypedExpr[A]): Boolean =
te match {
case TypedExpr.Generic(_, in) => isFn(n, in)
case TypedExpr.Annotation(te, _) => isFn(n, te)
case TypedExpr.Local(vn, _, _) => vn == n
case _ => false
}
/** assuming expr is bound to nm, what kind of self call does it contain?
*/
def apply[A](n: Bindable, te: TypedExpr[A]): SelfCallKind =
Expand All @@ -67,11 +74,8 @@ object SelfCallKind {
.reduce(SelfCallKind.ifNoCallSemigroup)

argsCall.ifNoCallThen(
fn match {
case TypedExpr.Local(vn, _, _) if vn == n =>
SelfCallKind.TailCall
case _ => apply(n, fn).callNotTail
}
if (isFn(n, fn)) SelfCallKind.TailCall
else apply(n, fn).callNotTail
)
case TypedExpr.Let(arg, ex, in, rec, _) =>
if (arg == n) {
Expand Down
214 changes: 162 additions & 52 deletions core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ sealed abstract class TypedExpr[+T] { self: Product =>
*/
lazy val getType: Type =
this match {
case Generic(params, expr) =>
Type.forAll(params, expr.getType)
case g@Generic(_, _) => g.forAllType
case Annotation(_, tpe) =>
tpe
case AnnotatedLambda(args, res, _) =>
Expand All @@ -47,6 +46,19 @@ sealed abstract class TypedExpr[+T] { self: Product =>
branches.head._2.getType
}

lazy val size: Int =
this match {
case Generic(_, g) => g.size
case Annotation(a, _) => a.size
case AnnotatedLambda(_, res, _) =>
res.size
case Local(_, _, _) | Literal(_, _, _) | Global(_, _, _, _) => 1
case App(fn, args, _, _) => fn.size + args.foldMap(_.size)
case Let(_, e, in, _, _) => e.size + in.size
case Match(a, branches, _) =>
a.size + branches.foldMap(_._2.size)
}

// TODO: we need to make sure this parsable and maybe have a mode that has the compiler
// emit these
def repr: String = {
Expand Down Expand Up @@ -174,6 +186,8 @@ object TypedExpr {
*/
case class Generic[T](typeVars: NonEmptyList[(Type.Var.Bound, Kind)], in: TypedExpr[T]) extends TypedExpr[T] {
def tag: T = in.tag

lazy val forAllType: Type.ForAll = Type.forAll(typeVars, in.getType)
}
// Annotation really means "widen", the term has a type that is a subtype of coerce, so we are widening
// to the given type. This happens on Locals/Globals also in their tpe
Expand Down Expand Up @@ -606,46 +620,148 @@ object TypedExpr {

type Coerce = FunctionK[TypedExpr, TypedExpr]

private def pushDownCovariant(tpe: Type, kinds: Type => Option[Kind]): Type = {
import Type._
tpe match {
case ForAll(targs, in) =>
val (cons, cargs) = Type.unapplyAll(in)
kinds(cons) match {
case None =>
// this can happen because the cons is some kind of type variable
// we have lost track of (we need to track the type variables in
// recursions)
tpe
case Some(kind) =>

val kindArgs = kind.toArgs
val kindArgsWithArgs = kindArgs.zip(cargs).map { case (ka, a) => (Some(ka), a) } :::
cargs.drop(kindArgs.length).map((None, _))

val argsVectorIdx = kindArgsWithArgs
.iterator
.zipWithIndex
.map { case ((optKA, tpe), idx) =>
(Type.freeBoundTyVars(tpe :: Nil).toSet, optKA, tpe, idx)
}
.toVector

// if an arg is covariant, it can pull all it's unique freeVars
def uniqueFreeVars(idx: Int): Set[Type.Var.Bound] = {
val (justIdx, optKA, _, _) = argsVectorIdx(idx)
if (optKA.exists(_.variance == Variance.co)) {
argsVectorIdx.iterator.filter(_._4 != idx)
.foldLeft(justIdx) { case (acc, (s, _, _, _)) => acc -- s }
}
else Set.empty
}
val withPulled = argsVectorIdx.map { case rec@(_, _, _, idx) =>
(rec, uniqueFreeVars(idx))
}
val allPulled: Set[Type.Var.Bound] = withPulled.foldMap(_._2)
val nonpulled = targs.filterNot { case (v, _) => allPulled(v) }
val pulledArgs = withPulled.iterator.map { case ((_, _, tpe, _), uniques) =>
val keep: Type.Var.Bound => Boolean = uniques
Type.forAll(targs.filter { case (t, _) => keep(t) }, tpe)
}
.toList
Type.forAll(nonpulled, Type.applyAll(cons, pulledArgs))
}

case _ => tpe
}
}

// We know initTpe <:< instTpe, we may be able to simply
// fix some of the universally quantified variables
private def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): Option[TypedExpr[A]] =
gen.getType match {
case Type.ForAll(bs, in) =>
import Type._
def solve(left: Type, right: Type, state: Map[Type.Var, Type], solveSet: Set[Type.Var]): Option[Map[Type.Var, Type]] =
(left, right) match {
case (TyVar(v), right) if solveSet(v) =>
Some(state.updated(v, right))
case (ForAll(b, i), r) =>
// this will mask solving for the inside values:
solve(i, r, state, solveSet -- b.toList.iterator.map(_._1))
case (_, ForAll(_, _)) =>
// TODO:
// if cons is covariant, all the free params of arg
// not in cons can pushed into arg
None
case (TyApply(on, arg), TyApply(on2, arg2)) =>
for {
s1 <- solve(on, on2, state, solveSet)
s2 <- solve(arg, arg2, s1, solveSet)
} yield s2
case (TyConst(_) | TyMeta(_) | TyVar(_), _) =>
if (left == right) {
// can't recurse further into left
Some(state)
}
else None
case (TyApply(_, _), _) => None
def instantiateTo[A](gen: Generic[A], instTpe: Type.Rho, kinds: Type => Option[Kind]): TypedExpr[A] = {
import Type._

def solve(left: Type,
right: Type,
state: Map[Type.Var, Type],
solveSet: Set[Type.Var],
varKinds: Map[Type.Var, Kind]): Option[Map[Type.Var, Type]] =
(left, right) match {
case (TyVar(v), right) if solveSet(v) =>
Some(state.updated(v, right))
case (fa @ ForAll(b, i), r) =>
if (fa.sameAs(r)) Some(state)
// this will mask solving for the inside values:
else solve(i,
r,
state,
solveSet -- b.toList.iterator.map(_._1),
varKinds ++ b.toList
)
case (_, fa@ForAll(_, _)) =>
val kindsWithVars: Type => Option[Kind] =
{
case v: Type.Var => varKinds.get(v)
case t => kinds(t)
}
val fa1 = pushDownCovariant(fa, kindsWithVars)
if (fa1 != fa) solve(left, fa1, state, solveSet, varKinds)
else {
// not clear what to do here,
// the examples that come up look like un-unified
// types, as if coerceRho is called before we have
// finished unifying
None
}
case (TyApply(on, arg), TyApply(on2, arg2)) =>
for {
s1 <- solve(on, on2, state, solveSet, varKinds)
s2 <- solve(arg, arg2, s1, solveSet, varKinds)
} yield s2
case (TyConst(_) | TyMeta(_) | TyVar(_), _) =>
if (left == right) {
// can't recurse further into left
Some(state)
}
else None
case (TyApply(_, _), _) => None
}

val solveSet: Set[Var] = bs.toList.iterator.map(_._1).toSet
solve(in, instTpe, Map.empty, solveSet)
.flatMap { subs =>
if (subs.keySet == solveSet) Some(substituteTypeVar(gen.in, subs))
else None
val Type.ForAll(bs, in) = gen.forAllType
val solveSet: Set[Var] = bs.toList.iterator.map(_._1).toSet

val result =
solve(in, instTpe, Map.empty, solveSet, bs.toList.toMap)
.map { subs =>
val freeVars = solveSet -- subs.keySet
val subBody = substituteTypeVar(gen.in, subs)
val freeTypeVars = gen.typeVars.filter { case (t, _) => freeVars(t) }
NonEmptyList.fromList(freeTypeVars) match {
case None => subBody
case Some(frees) =>
val newGen = Generic(frees, subBody)
pushGeneric(newGen) match {
case badOpt @ (None | Some(Generic(_, _)))=>
// just wrap
Annotation(badOpt.getOrElse(newGen), instTpe)
case Some(notGen) => notGen
}
}
case _ => None
}

result match {
case None =>
// TODO some of these just don't look fully unified yet, for instance:
// could not solve instantiate:
//
// forall b: *. Bosatsu/Predef::Order[b] -> forall a: *. Bosatsu/Predef::Dict[b, a]
//
// to
//
// Bosatsu/Predef::Order[?338] -> Bosatsu/Predef::Dict[$k$303, $v$304]
// but those two types aren't the same. It seems like we have to later
// learn that ?338 == $k$303, but we don't seem to know that yet

// just add an annotation:
Annotation(gen, instTpe)
case Some(res) => res
}
}

private def allPatternTypes[N](p: Pattern[N, Type]): SortedSet[Type] =
p.traverseType { t => Writer[SortedSet[Type], Type](SortedSet(t), t) }.run._1
Expand All @@ -654,13 +770,15 @@ object TypedExpr {
g.in match {
case AnnotatedLambda(args, body, a) =>
val argFree = Type.freeBoundTyVars(args.toList.map(_._2)).toSet
if (g.typeVars.exists { case (b, _) => argFree(b) }) {
None
}
else {
val gbody = Generic(g.typeVars, body)
val (outer, inner) = g.typeVars.toList.partition { case (b, _) => argFree(b) }
NonEmptyList.fromList(inner).map { inner =>
val gbody = Generic(inner, body)
val pushedBody = pushGeneric(gbody).getOrElse(gbody)
Some(AnnotatedLambda(args, pushedBody, a))
val lam = AnnotatedLambda(args, gbody, a)
NonEmptyList.fromList(outer) match {
case None => lam
case Some(outer) => forAll(outer, lam)
}
}
// we can do the same thing on Match
case Match(arg, branches, tag) =>
Expand Down Expand Up @@ -719,12 +837,7 @@ object TypedExpr {
pushGeneric(gen) match {
case Some(e1) => self(e1)
case None =>
instantiateTo(gen, tpe, kinds) match {
case Some(res) => res
case None =>
// TODO: this is basically giving up
Annotation(gen, tpe)
}
instantiateTo(gen, tpe, kinds)
}
case App(fn, aargs, _, tag) =>
fn match {
Expand Down Expand Up @@ -959,10 +1072,7 @@ object TypedExpr {
pushGeneric(gen) match {
case Some(e1) => self(e1)
case None =>
instantiateTo(gen, fntpe, kinds) match {
case Some(res) => res
case None => Annotation(gen, fntpe)
}
instantiateTo(gen, fntpe, kinds)
}
case Local(_, _, _) | Global(_, _, _, _) | Literal(_, _, _) =>
Annotation(expr, fntpe)
Expand Down
Loading

0 comments on commit 0fbec80

Please sign in to comment.