Skip to content

Commit

Permalink
Fix long bug in error regions (#1049)
Browse files Browse the repository at this point in the history
* Fix long bug in error regions

* add some more test coverage

* add ExprTest
  • Loading branch information
johnynek authored Sep 20, 2023
1 parent 423a144 commit cef4c90
Show file tree
Hide file tree
Showing 10 changed files with 791 additions and 385 deletions.
247 changes: 145 additions & 102 deletions core/src/main/scala/org/bykn/bosatsu/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,115 @@ package org.bykn.bosatsu

import cats.implicits._
import cats.data.{Chain, Writer, NonEmptyList}
import cats.{Applicative, Eval, Traverse}
import cats.Applicative
import scala.collection.immutable.SortedSet
import org.bykn.bosatsu.rankn.Type

import Identifier.{Bindable, Constructor}

sealed abstract class Expr[T] {
def tag: T

/**
* All the free variables in this expression in order
* encountered and with duplicates (to see how often
* they appear)
*/
lazy val freeVarsDup: List[Bindable] = {
import Expr._
// nearly identical code to TypedExpr.freeVarsDup, bugs should be fixed in both places
this match {
case Generic(_, expr) =>
expr.freeVarsDup
case Annotation(t, _, _) =>
t.freeVarsDup
case Local(ident, _) =>
ident :: Nil
case Global(_, _, _) =>
Nil
case Lambda(args, res, _) =>
val nameSet = args.toList.iterator.map(_._1).toSet
ListUtil.filterNot(res.freeVarsDup)(nameSet)
case App(fn, args, _) =>
fn.freeVarsDup ::: args.reduceMap(_.freeVarsDup)
case Let(arg, argE, in, rec, _) =>
val argFree0 = argE.freeVarsDup
val argFree =
if (rec.isRecursive) {
ListUtil.filterNot(argFree0)(_ === arg)
}
else argFree0

argFree ::: (ListUtil.filterNot(in.freeVarsDup)(_ === arg))
case Literal(_, _) =>
Nil
case Match(arg, branches, _) =>
val argFree = arg.freeVarsDup

val branchFrees = branches.toList.map { case (p, b) =>
// these are not free variables in this branch
val newBinds = p.names.toSet
val bfree = b.freeVarsDup
if (newBinds.isEmpty) bfree
else ListUtil.filterNot(bfree)(newBinds)
}
// we can only take one branch, so count the max on each branch:
val branchFreeMax = branchFrees
.zipWithIndex
.flatMap { case (names, br) => names.map((_, br)) }
// these groupBys are okay because we sort at the end
.groupBy(identity) // group-by-name x branch
.map { case ((name, branch), names) => (names.length, branch, name) }
.groupBy(_._3) // group by just the name now
.toList
.flatMap { case (_, vs) =>
val (cnt, branch, name) = vs.maxBy(_._1)
List.fill(cnt)((branch, name))
}
.sorted
.map(_._2)

argFree ::: branchFreeMax
}
}

lazy val globals: Set[Expr.Global[T]] = {
import Expr._
this match {
case Generic(_, expr) =>
expr.globals
case Annotation(t, _, _) =>
t.globals
case Local(_, _) => Set.empty
case g @ Global(_, _, _) => Set.empty + g
case Lambda(_, res, _) => res.globals
case App(fn, args, _) =>
fn.globals | args.reduceMap(_.globals)
case Let(_, argE, in, _, _) =>
argE.globals | in.globals
case Literal(_, _) => Set.empty
case Match(arg, branches, _) =>
arg.globals | branches.reduceMap { case (_, b) => b.globals }
}
}

def replaceTag(t: T): Expr[T] = {
import Expr._
this match {
case g@Generic(_, e) => g.copy(in = e.replaceTag(t))
case a@Annotation(_, _, _) => a.copy(tag = t)
case l@Local(_, _) => l.copy(tag = t)
case g @ Global(_, _, _) => g.copy(tag = t)
case l@Lambda(_, _, _) => l.copy(tag = t)
case a@App(_, _, _) => a.copy(tag = t)
case l@Let(_, _, _, _, _) => l.copy(tag = t)
case l@Literal(_, _) => l.copy(tag = t)
case m@Match(_, _, _) => m.copy(tag = t)
}
}

def notFree(b: Bindable): Boolean =
!freeVarsDup.contains(b)
}

object Expr {
Expand All @@ -28,10 +129,52 @@ object Expr {
case class Global[T](pack: PackageName, name: Identifier, tag: T) extends Name[T]
case class App[T](fn: Expr[T], args: NonEmptyList[Expr[T]], tag: T) extends Expr[T]
case class Lambda[T](args: NonEmptyList[(Bindable, Option[Type])], expr: Expr[T], tag: T) extends Expr[T]
case class Let[T](arg: Bindable, expr: Expr[T], in: Expr[T], recursive: RecursionKind, tag: T) extends Expr[T]
case class Let[T](arg: Bindable, expr: Expr[T], in: Expr[T], recursive: RecursionKind, tag: T) extends Expr[T] {
def flatten: (NonEmptyList[(Bindable, RecursionKind, Expr[T], T)], Expr[T]) = {
val thisLet = (arg, recursive, expr, tag)

in match {
case let@Let(_, _, _, _, _) =>
val (lets, finalIn) = let.flatten
(thisLet :: lets, finalIn)
case _ =>
// this is the final let
(NonEmptyList.one(thisLet), in)
}
}
}
case class Literal[T](lit: Lit, tag: T) extends Expr[T]
case class Match[T](arg: Expr[T], branches: NonEmptyList[(Pattern[(PackageName, Constructor), Type], Expr[T])], tag: T) extends Expr[T]

// Inverse of `Let.flatten`
def lets[T](binds: List[(Bindable, RecursionKind, Expr[T], T)], in: Expr[T]): Expr[T] =
binds match {
case Nil => in
case (b, r, e, t) :: tail =>
val res = lets(tail, in)
Let(b, e, res, r, t)
}

object Annotated {
def unapply[A](expr: Expr[A]): Option[Type] =
expr match {
case Annotation(_, tpe, _) => Some(tpe)
case Lambda(args, Annotated(res), _) =>
args.traverse { case (_, ot) => ot }
.map { argTpes =>
Type.Fun(argTpes, res)
}
case Literal(lit, _) => Some(Type.getTypeOf(lit))
case Let(_, _, Annotated(t), _, _) => Some(t)
case Match(_, branches, _) =>
branches.traverse { case (_, expr) => unapply(expr) }
.flatMap { allAnnotated =>
if (allAnnotated.tail.forall(_ === allAnnotated.head)) Some(allAnnotated.head)
else None
}
case _ => None
}
}

def forAll[A](tpeArgs: List[(Type.Var.Bound, Kind)], expr: Expr[A]): Expr[A] =
NonEmptyList.fromList(tpeArgs) match {
Expand Down Expand Up @@ -188,106 +331,6 @@ object Expr {
}
}

/*
* We have seen some intermitten CI failures if this isn't lazy
* presumably due to initialiazation order
*/
implicit lazy val exprTraverse: Traverse[Expr] =
new Traverse[Expr] {

// Traverse on NonEmptyList[(Pattern[_], Expr[?])]
private lazy val tne = {
type Tup[T] = (Pattern[(PackageName, Constructor), Type], T)
type TupExpr[T] = (Pattern[(PackageName, Constructor), Type], Expr[T])
val tup: Traverse[TupExpr] = Traverse[Tup].compose(exprTraverse)
Traverse[NonEmptyList].compose(tup)
}

def traverse[G[_]: Applicative, A, B](fa: Expr[A])(f: A => G[B]): G[Expr[B]] =
fa match {
case Annotation(e, tpe, a) =>
(e.traverse(f), f(a)).mapN(Annotation(_, tpe, _))
case Local(s, t) =>
f(t).map(Local(s, _))
case Global(p, s, t) =>
f(t).map(Global(p, s, _))
case Generic(bs, e) =>
traverse(e)(f).map(Generic(bs, _))
case App(fn, args, t) =>
(fn.traverse(f), args.traverse(_.traverse(f)), f(t)).mapN { (fn1, a1, b) =>
App(fn1, a1, b)
}
case Lambda(args, expr, t) =>
(expr.traverse(f), f(t)).mapN { (e1, t1) =>
Lambda(args, e1, t1)
}
case Let(arg, exp, in, rec, tag) =>
(exp.traverse(f), in.traverse(f), f(tag)).mapN { (e1, i1, t1) =>
Let(arg, e1, i1, rec, t1)
}
case Literal(lit, tag) =>
f(tag).map(Literal(lit, _))
case Match(arg, branches, tag) =>
val argB = arg.traverse(f)
val branchB = tne.traverse(branches)(f)
(argB, branchB, f(tag)).mapN { (a, bs, t) =>
Match(a, bs, t)
}
}

def foldLeft[A, B](fa: Expr[A], b: B)(f: (B, A) => B): B =
fa match {
case Annotation(e, _, tag) =>
val b1 = foldLeft(e, b)(f)
f(b1, tag)
case n: Name[A] => f(b, n.tag)
case App(fn, args, tag) =>
val b1 = foldLeft(fn, b)(f)
val b2 = args.foldLeft(b1) { (b1, x) => foldLeft(x, b1)(f) }
f(b2, tag)
case Generic(_, in) => foldLeft(in, b)(f)
case Lambda(_, expr, tag) =>
val b1 = foldLeft(expr, b)(f)
f(b1, tag)
case Let(_, exp, in, _, tag) =>
val b1 = foldLeft(exp, b)(f)
val b2 = foldLeft(in, b1)(f)
f(b2, tag)
case Literal(_, tag) =>
f(b, tag)
case Match(arg, branches, tag) =>
val b1 = foldLeft(arg, b)(f)
val b2 = tne.foldLeft(branches, b1)(f)
f(b2, tag)
}

def foldRight[A, B](fa: Expr[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] =
fa match {
case Annotation(e, _, tag) =>
val lb1 = foldRight(e, lb)(f)
f(tag, lb1)
case n: Name[A] => f(n.tag, lb)
case App(fn, args, tag) =>
val b1 = f(tag, lb)
val b2 = args.foldRight(b1)((a, b1) => foldRight(a, b1)(f))
foldRight(fn, b2)(f)
case Generic(_, in) => foldRight(in, lb)(f)
case Lambda(_, expr, tag) =>
val b1 = f(tag, lb)
foldRight(expr, b1)(f)
case Let(_, exp, in, _, tag) =>
val b1 = f(tag, lb)
val b2 = foldRight(in, b1)(f)
foldRight(exp, b2)(f)
case Literal(_, tag) =>
f(tag, lb)
case Match(arg, branches, tag) =>
val b1 = f(tag, lb)
val b2 = tne.foldRight(branches, b1)(f)
foldRight(arg, b2)(f)
}
}

def buildPatternLambda[A](
args: NonEmptyList[Pattern[(PackageName, Constructor), Type]],
body: Expr[A],
Expand Down
41 changes: 41 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/ListUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package org.bykn.bosatsu

import cats.data.NonEmptyList

private[bosatsu] object ListUtil {
// filter b from a pretty short lst but try to conserve lst if possible
def filterNot[A](lst: List[A])(b: A => Boolean): List[A] =
lst match {
case Nil => lst
case h :: tail =>
val t1 = filterNot(tail)(b)
if (b(h)) t1
else if (t1 eq tail) lst
else (h :: t1) // we only allocate here
}

def greedyGroup[A, G](list: NonEmptyList[A])(one: A => G)(combine: (G, A) => Option[G]): NonEmptyList[G] = {
def loop(g: G, tail: List[A]): NonEmptyList[G] =
tail match {
case Nil => NonEmptyList.one(g)
case tailh :: tailt =>
combine(g, tailh) match {
case None =>
// can't combine into the head group, start a new group
g :: loop(one(tailh), tailt)
case Some(g1) =>
// we can combine into a new group
loop(g1, tailt)
}
}

loop(one(list.head), list.tail)
}

def greedyGroup[A, G](list: List[A])(one: A => G)(combine: (G, A) => Option[G]): List[G] =
NonEmptyList.fromList(list) match {
case None => Nil
case Some(nel) => greedyGroup(nel)(one)(combine).toList
}

}
Loading

0 comments on commit cef4c90

Please sign in to comment.