Skip to content

Commit

Permalink
checkpoint, but List sort is broken by the changes
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Sep 23, 2023
1 parent 2248a6c commit a21ef20
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 85 deletions.
14 changes: 9 additions & 5 deletions core/src/main/scala/org/bykn/bosatsu/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ object Expr {
}
}

private[bosatsu] def nameIterator(): Iterator[Bindable] =
Type
.allBinders
.iterator
.map(_.name)
.map(Identifier.Name(_))


def buildPatternLambda[A](
args: NonEmptyList[Pattern[(PackageName, Constructor), Type]],
body: Expr[A],
Expand All @@ -340,11 +348,7 @@ object Expr {
* compute this once if needed, which is why it is lazy.
* we don't want to traverse body if it is never needed
*/
lazy val anons = Type
.allBinders
.iterator
.map(_.name)
.map(Identifier.Name(_))
lazy val anons = nameIterator()
.filterNot(allNames(body) ++ args.patternNames)

type P = Pattern[(PackageName, Constructor), Type]
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/ListUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,14 @@ private[bosatsu] object ListUtil {
if (bs eq as) nel
else NonEmptyList.fromListUnsafe(bs)
}

@annotation.tailrec
def find[X, Y](ls: List[X])(fn: X => Option[Y]): Option[Y] =
ls match {
case Nil => None
case h :: t => fn(h) match {
case None => find(t)(fn)
case some => some
}
}
}
217 changes: 137 additions & 80 deletions core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,21 @@ object TypedExprNormalization {
loop(scope1, tail, (b, r, normTE) :: acc)
}

loop(emptyScope, lets, Nil)
val res = loop(emptyScope, lets, Nil)
if (pack == PackageName.parts("Bosatsu", "List")) {
lets.zip(res).foreach { case ((b, _, in), (_, _, out)) =>
val same = in == out
println(s"$b was ${if (same) "not " else ""} changed")
println("=======init========")
println(in.repr)
if (!same) {
println("======final========")
println(out.repr)
}
println("===================")
}
}
res
}

def normalizeProgram[A, V](
Expand All @@ -75,7 +89,7 @@ object TypedExprNormalization {
/**
* if the te is not in normal form, transform it into normal form
*/
def normalizeLetOpt[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): Option[TypedExpr[A]] = {
private def normalizeLetOpt[A, V](namerec: Option[Bindable], te: TypedExpr[A], scope: Scope[A], typeEnv: TypeEnv[V])(implicit ev: V <:< Kind.Arg): Option[TypedExpr[A]] = {
val kindOf: Type => Option[Kind] =
{ case const @ Type.TyConst(_) =>
typeEnv.getType(const).map(_.kindOf)
Expand Down Expand Up @@ -124,6 +138,8 @@ object TypedExprNormalization {
Some(e1)
case (gen@Generic(_, _), rho: Type.Rho) =>
val inst = TypedExpr.instantiateTo(gen, rho, kindOf)
// we compare thes to te because instantiate
// can add an Annotation back
if (inst != te) Some(inst)
else None
case (notSameTpe, _) =>
Expand All @@ -136,7 +152,35 @@ object TypedExprNormalization {
}

case AnnotatedLambda(lamArgs0, expr, tag) =>
val lamArgs = lamArgs0.map { case (n, tpe0) => n -> Type.normalize(tpe0) }
lazy val anons: Iterator[Bindable] = Expr.nameIterator()
.filterNot(expr.freeVarsDup.toSet)

val e1 = normalize1(None, expr, scope, typeEnv).get

var changed = false
val lamArgs = lamArgs0.map { case (n, t) =>
val n1 =
if (e1.notFree(n)) {
// n is not used.
val next = anons.next()
changed = changed || (next != n)
next
}
else {
n
}
(n1, Type.normalize(t))
}

if (changed) {
normalize1(namerec,
AnnotatedLambda(lamArgs, e1, tag),
scope,
typeEnv)

}
else {

def doesntUseArgs(te: TypedExpr[A]): Boolean =
lamArgs.forall { case (n, _) => te.notFree(n) }

Expand All @@ -150,13 +194,18 @@ object TypedExprNormalization {
case _ => false
}

// we can normalize the arg to the smallest non-free var
// x -> f(x) == f (eta conversion)
// x -> generic(g) = generic(x -> g) if the type of x doesn't have free types with vars
val e1 = normalize1(None, expr, scope, typeEnv).get
val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv))
e1 match {
case App(fn, aargs, _, _) if matchesArgs(aargs) && doesntUseArgs(fn) =>
// x -> f(x) == f (eta conversion)
normalize1(None, setType(fn, te.getType), scope, typeEnv)
case App(ws.ResolveToLambda(Nil, args1, body, ftag), aargs, resT, tag) =>
// args -> (args1 -> e1)(...)
// this is inlining, which we do only when nested directly inside another lambda
val fn1 = AnnotatedLambda(args1, body, ftag)
normalize1(namerec,
AnnotatedLambda(lamArgs, App(fn1, aargs, resT, tag), tag),
scope, typeEnv)
case Let(arg1, ex, in, rec, tag1) if doesntUseArgs(ex) && doesntShadow(arg1) =>
// x ->
// y = z
Expand Down Expand Up @@ -187,18 +236,19 @@ object TypedExprNormalization {
if ((notApp eq expr) && (lamArgs === lamArgs0)) None
else Some(AnnotatedLambda(lamArgs, notApp, tag))
}
}
case Literal(_, _, _) =>
// these are fundamental
None
case Global(p, n: Constructor, tpe0, tag) =>
val tpe = Type.normalize(tpe0)
if (tpe == tpe0) None
else Some(Global(p, n, tpe, tag))
case Literal(_, _, _) =>
// these are fundamental
None
case Global(p, n: Bindable, tpe0, tag) =>
scope.getGlobal(p, n).flatMap {

case (RecursionKind.NonRecursive, te, _) if Impl.isSimple(te, lambdaSimple = false) =>
// TODO for a reason I don't understand, inlining lambdas here causes a stack overflow
// there is probably something somewhat unsound about this substitution that I don't understand
// inlining lambdas naively can cause an exponential blow up in size
Some(te)
case _ =>
val tpe = Type.normalize(tpe0)
Expand All @@ -222,9 +272,8 @@ object TypedExprNormalization {
normalize1(None, a, scope, typeEnv).get
}

val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv))
f1 match {
case ws.ResolveToLambda(Nil, lamArgs, expr, _) =>
case AnnotatedLambda(lamArgs, expr, _) =>
// (y -> z)(x) = let y = x in z
val lets = lamArgs.zip(args).map {
case ((n, ltpe), arg) => (n, setType(arg, ltpe))
Expand All @@ -236,7 +285,7 @@ object TypedExprNormalization {
// (app (let x y z) w) == (let x y (app z w)) if w does not have x free
normalize1(namerec, Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), scope, typeEnv)
case _ =>
if ((f1 eq fn) && (a1 eq args) && (tpe == tpe0)) None
if ((f1 eq fn) && (tpe == tpe0) && (a1 eq args)) None
else Some(App(f1, a1, tpe, tag))
}
case Let(arg, ex, in, rec, tag) =>
Expand Down Expand Up @@ -358,7 +407,9 @@ object TypedExprNormalization {
// does not depend on the arg
if (a1 eq arg) None
else Some(m1)
case Some(m2) if m2.size < m1.size =>
case Some(m2) =>
// TODO: we may not have a proof that m2 is smaller
// than m1. requiring m2.size < m1.size fails some tests
// we can possibly simplify this now:
normalize1(namerec, m2, scope, typeEnv)
case _ => None
Expand Down Expand Up @@ -494,19 +545,20 @@ object TypedExprNormalization {
case FnArgs(fn, args) =>
evaluate(fn, scope).map {
case EvalResult.Cons(p, c, ahead) => EvalResult.Cons(p, c, ahead ::: args.toList)
// $COVERAGE-OFF$
case EvalResult.Constant(c) =>
// this really shouldn't happen,
// $COVERAGE-OFF$
sys.error(s"unreachable: cannot apply a constant: $te => ${fn.repr} => $c")
// $COVERAGE-ON$
// $COVERAGE-ON$
}
case Global(pack, cons: Constructor, _, _) => Some(EvalResult.Cons(pack, cons, Nil))
case Global(pack, n: Bindable, _, _) =>
scope.getGlobal(pack, n).flatMap {
case (_, t, s) =>
case (RecursionKind.NonRecursive, t, s) =>
// Global values never have free values,
// so it is safe to substitute into our current scope
evaluate(t, s)
case _ => None
}
case Generic(_, in) =>
// if we can evaluate, we are okay
Expand Down Expand Up @@ -536,24 +588,33 @@ object TypedExprNormalization {
}

// The Option signals we can't complete
def expandMatches(br: Branch[A]): Option[List[Branch[A]]] =
br match {
case (ps@Pattern.PositionalStruct((p0, c0), args0), res) =>
if (p0 == p && c0 == c && args0.length == alen) Some((ps, res) :: Nil)
else Some(Nil)
case (Pattern.Named(n, p), res) =>
expandMatches((p, res)).map { bs =>
bs.map { case (bp, br) =>
(Pattern.Named(n, bp), br)
}
def filterPat(pat: Pat): Option[Option[Pat]] =
pat match {
case ps@Pattern.PositionalStruct((p0, c0), args0) =>
if (p0 == p && c0 == c && args0.length == alen) Some(Some(ps))
else Some(None) // we definitely don't match this branch
case Pattern.Named(n, p) =>
filterPat(p).map { p1 =>
p1.map { bp => Pattern.Named(n, bp) }
}
case (Pattern.Annotation(p, _), res) =>
case Pattern.Annotation(p, _) =>
// The annotation is only used at inference time, the values have already been typed
expandMatches((p, res))
case (Pattern.Union(h, t), r) =>
(h :: t.toList).traverse { p => expandMatches((p, r)) }.map(_.flatten)
case br@(p, _) if isTotal(p) => Some(br :: Nil)
case (Pattern.ListPat(_), _) =>
filterPat(p)
case Pattern.Union(h, t) =>
(filterPat(h), t.traverse(filterPat))
.mapN { (optP1, p2s) =>
val flatP2s: List[Pat] = p2s.toList.flatten
optP1 match {
case None =>
flatP2s match {
case Nil => None
case h :: t => Some(Pattern.union(h, t))
}
case Some(p1) => Some(Pattern.union(p1, flatP2s))
}
}
case Pattern.WildCard | Pattern.Var(_) => Some(Some(pat))
case Pattern.ListPat(_) =>
// TODO some of these patterns we could evaluate
None
case _ => None
Expand All @@ -571,42 +632,48 @@ object TypedExprNormalization {
}
}

m.branches.toList.traverse(expandMatches).map(_.flatten).flatMap {
case Nil =>
// TODO hitting this looks like a bug
// $COVERAGE-OFF$
//sys.error(s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})")
// $COVERAGE-ON$
None
case (MaybeNamedStruct(b, pats), r) :: rest if rest.isEmpty || pats.forall(isTotal) =>
// If there are no more items, or all inner patterns are total, we are done

// exactly one matches, this can be a sequential match
def matchAll(argPat: List[(TypedExpr[A], Pattern[(PackageName, Constructor), Type])]): TypedExpr[A] =
argPat match {
case Nil => r
case (a, p) :: tail =>
val tr = matchAll(tail)
p match {
case Pattern.WildCard =>
// we don't care about this value
tr
case Pattern.Var(b) =>
Let(b, a, tr, RecursionKind.NonRecursive, m.tag)
case _ =>
// This will get simplified later
Match(a, NonEmptyList.one((p, tr)), m.tag)
m.branches
.traverse { case (p, r) => filterPat(p).map((_, r)) }
// if we can check all the branches for a match, maybe we can evaluate
.flatMap { branches =>
val candidates: List[(Pat, TypedExpr[A])] =
branches.collect { case (Some(p), r) => (p, r)}

candidates match {
// $COVERAGE-OFF$
case Nil =>
// TODO hitting this looks like a bug
sys.error(s"no branch matched in ${m.repr} matched: $p::$c(${args.map(_.repr)})")
// $COVERAGE-ON$
case (MaybeNamedStruct(b, pats), r) :: rest if rest.isEmpty || pats.forall(isTotal) =>
// If there are no more items, or all inner patterns are total, we are done
// exactly one matches, this can be a sequential match
def matchAll(argPat: List[(TypedExpr[A], Pattern[(PackageName, Constructor), Type])]): TypedExpr[A] =
argPat match {
case Nil => r
case (a, p) :: tail =>
val tr = matchAll(tail)
p match {
case Pattern.WildCard =>
// we don't care about this value
tr
case Pattern.Var(b) =>
Let(b, a, tr, RecursionKind.NonRecursive, m.tag)
case _ =>
// This will get simplified later
Match(a, NonEmptyList.one((p, tr)), m.tag)
}
}
}

val res = matchAll(args.zip(pats))
Some(b.foldRight(res)(Let(_, m.arg, _, RecursionKind.NonRecursive, m.tag)))
case h :: t =>
// more than one branch might match, wait till runtime
val m1 = Match(m.arg, NonEmptyList(h, t), m.tag)
if (m1 == m) None
else Some(m1)
}
val res = matchAll(args.zip(pats))
Some(b.foldRight(res)(Let(_, m.arg, _, RecursionKind.NonRecursive, m.tag)))
case h :: t =>
// more than one branch might match, wait till runtime
val m1 = Match(m.arg, NonEmptyList(h, t), m.tag)
if (m1 == m) None
else Some(m1)
}
}

case EvalResult.Constant(li @ Lit.Integer(i)) =>
def makeLet(p: Pattern[(PackageName, Constructor), Type]): Option[List[Bindable]] =
Expand All @@ -626,17 +693,7 @@ object TypedExprNormalization {
// $COVERAGE-ON$
}

@annotation.tailrec
def find[X, Y](ls: List[X])(fn: X => Option[Y]): Option[Y] =
ls match {
case Nil => None
case h :: t => fn(h) match {
case None => find(t)(fn)
case some => some
}
}

find[Branch[A], TypedExpr[A]](m.branches.toList) { case (p, r) =>
ListUtil.find[Branch[A], TypedExpr[A]](m.branches.toList) { case (p, r) =>
makeLet(p).map { names =>
val lit = Literal[A](li, Type.getTypeOf(li), m.tag)
// all these names are bound to the lit
Expand Down

0 comments on commit a21ef20

Please sign in to comment.