diff --git a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 5498f543b..f635aa837 100644 --- a/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/cli/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite { To inspect the code, change the hash, and it will print the code out */ testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")( - "260c81bc79b6232a3f174cb9afc04143" + "ccbf676b90cf04397c908d23f86b6434" ) } } \ No newline at end of file diff --git a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala index 203dc68e7..d6761b34c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/Pattern.scala +++ b/core/src/main/scala/org/bykn/bosatsu/Pattern.scala @@ -89,6 +89,36 @@ sealed abstract class Pattern[+N, +T] { Nil } + def substitute(table: Map[Bindable, Bindable]): Pattern[N, T] = + this match { + case Pattern.WildCard | Pattern.Literal(_) => this + case Pattern.Var(b) => + table.get(b) match { + case None => this + case Some(b1) => Pattern.Var(b1) + } + case Pattern.Named(n, p) => + val p1 = p.substitute(table) + val n2 = table.get(n) match { + case None => n + case Some(n1) => n1 + } + if ((p1 eq p) && (n2 eq n)) this + else Pattern.Named(n2, p1) + case Pattern.Annotation(p, t) => + val p1 = p.substitute(table) + if (p1 eq p) this + else Pattern.Annotation(p1, t) + case Pattern.Union(h, t) => + Pattern.Union(h.substitute(table), t.map(_.substitute(table))) + case Pattern.PositionalStruct(n, pats) => + Pattern.PositionalStruct(n, pats.map(_.substitute(table))) + case Pattern.ListPat(parts) => + Pattern.ListPat(parts.map(_.substitute(table))) + case Pattern.StrPat(parts) => + Pattern.StrPat(parts.map(_.substitute(table))) + } + /** List all the names that strictly smaller than anything that would match * this pattern e.g. a top level var, would not be returned */ @@ -312,7 +342,24 @@ object Pattern { extends NamedKind } - sealed abstract class StrPart + sealed abstract class StrPart { + import StrPart._ + + def substitute(table: Map[Bindable, Bindable]): StrPart = + this match { + case WildStr | LitStr(_) | WildChar => this + case NamedStr(n) => + table.get(n) match { + case None => this + case Some(n1) => NamedStr(n1) + } + case NamedChar(n) => + table.get(n) match { + case None => this + case Some(n1) => NamedChar(n1) + } + } + } object StrPart { final case object WildStr extends StrPart final case class NamedStr(name: Bindable) extends StrPart @@ -354,6 +401,22 @@ object Pattern { final case class Item[A](pat: A) extends ListPart[A] { def map[B](fn: A => B): ListPart[B] = Item(fn(pat)) } + + implicit class ListPartPat[N, T](val self: ListPart[Pattern[N, T]]) extends AnyVal { + def substitute(table: Map[Bindable, Bindable]): ListPart[Pattern[N, T]] = + self match { + case WildList => WildList + case NamedList(n) => + table.get(n) match { + case None => self + case Some(n1) => NamedList(n1) + } + case Item(p) => + val p1 = p.substitute(table) + if (p1 eq p) self + else Item(p1) + } + } } /** This will match any list without any binding diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala index c932cf944..8ee2bc0b4 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExpr.scala @@ -62,52 +62,54 @@ sealed abstract class TypedExpr[+T] { self: Product => def repr: Doc = { def rept(t: Type): Doc = Type.fullyResolvedDocument.document(t) + def block(d: Doc): Doc = d.nested(4).grouped + def loop(te: TypedExpr[T]): Doc = te match { case g @ Generic(_, expr) => - (Doc.text("(generic") + Doc.lineOrSpace + rept( + block(Doc.text("(generic") + Doc.line + rept( g.quantType - ) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) + ) + Doc.line + loop(expr) + Doc.char(')')) case Annotation(expr, tpe) => - (Doc.text("(ann") + Doc.lineOrSpace + rept( + block(Doc.text("(ann") + Doc.line + rept( tpe - ) + Doc.lineOrSpace + loop(expr) + Doc.char(')')).nested(4) + ) + Doc.line + loop(expr) + Doc.char(')')) case AnnotatedLambda(args, res, _) => - (Doc.text("(lambda") + Doc.lineOrSpace + ( - Doc.char('[') + Doc.intercalate( - Doc.lineOrSpace, + block(Doc.text("(lambda") + Doc.line + ( + Doc.char('[') + block(Doc.intercalate( + Doc.line, args.toList.map { case (arg, tpe) => - Doc.text(arg.sourceCodeRepr) + Doc.lineOrSpace + rept(tpe) + Doc.text(arg.sourceCodeRepr) + Doc.line + rept(tpe) } - ) + Doc.char(']') - ) + Doc.lineOrSpace + loop(res) + Doc.char(')')).nested(4) + )) + Doc.char(']') + ) + Doc.line + loop(res) + Doc.char(')')) case Local(v, tpe, _) => - (Doc.text("(var") + Doc.lineOrSpace + Doc.text( + block(Doc.text("(var") + Doc.line + Doc.text( v.sourceCodeRepr - ) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + ) + Doc.line + rept(tpe) + Doc.char(')')) case Global(p, v, tpe, _) => val pstr = Doc.text(p.asString + "::" + v.sourceCodeRepr) - (Doc.text("(var") + Doc.lineOrSpace + pstr + Doc.lineOrSpace + rept( + block(Doc.text("(var") + Doc.line + pstr + Doc.line + rept( tpe - ) + Doc.char(')')).nested(4) + ) + Doc.char(')')) case App(fn, args, tpe, _) => - val argsDoc = Doc.intercalate(Doc.lineOrSpace, args.toList.map(loop)) - (Doc.text("(ap") + Doc.lineOrSpace + loop( + val argsDoc = block(Doc.intercalate(Doc.line, args.toList.map(loop))) + block(Doc.text("(ap") + Doc.line + loop( fn - ) + Doc.lineOrSpace + argsDoc + Doc.lineOrSpace + rept(tpe) + Doc - .char(')')).nested(4) + ) + Doc.line + argsDoc + Doc.line + rept(tpe) + Doc + .char(')')) case Let(n, b, in, rec, _) => val nm = if (rec.isRecursive) Doc.text("(letrec") else Doc.text("(let") - (nm + Doc.lineOrSpace + Doc.text( + block(nm + Doc.line + Doc.text( n.sourceCodeRepr - ) + Doc.lineOrSpace + loop(b) + Doc.lineOrSpace + loop(in) + Doc.char( + ) + Doc.line + loop(b) + Doc.line + loop(in) + Doc.char( ')' - )).nested(4) + )) case Literal(v, tpe, _) => - (Doc.text("(lit") + Doc.lineOrSpace + Doc.text( + block(Doc.text("(lit") + Doc.line + Doc.text( v.repr - ) + Doc.lineOrSpace + rept(tpe) + Doc.char(')')).nested(4) + ) + Doc.line + rept(tpe) + Doc.char(')')) case Match(arg, branches, _) => implicit val docType: Document[Type] = Document.instance(tpe => rept(tpe)) @@ -116,18 +118,17 @@ sealed abstract class TypedExpr[+T] { self: Product => cpat.document(p) val bstr = branches.toList.map { case (p, t) => - (Doc.char('[') + pat(p) + Doc.comma + Doc.lineOrSpace + loop( - t - ).grouped + Doc.char(']')).nested(4) + block(Doc.char('[') + pat(p) + Doc.comma + Doc.line + loop(t) + Doc.char(']')) } - (Doc.text("(match") + Doc.lineOrSpace + loop(arg) + (Doc.hardLine + - Doc.intercalate(Doc.hardLine, bstr)).nested(4) + Doc.char(')')) - .nested(4) + block(Doc.text("(match") + Doc.line + loop(arg) + block(Doc.hardLine + + Doc.intercalate(Doc.hardLine, bstr)) + Doc.char(')')) } loop(this) } + def reprString: String = repr.render(80) + /** All the free variables in this expression in order encountered and with * duplicates (to see how often they appear) */ @@ -194,7 +195,7 @@ object TypedExpr { type Rho[A] = TypedExpr[A] // an expression with a Rho type (no top level forall) - sealed abstract class Name[A] extends TypedExpr[A] with Product + sealed abstract class Name[+A] extends TypedExpr[A] with Product /** This says that the resulting term is generic on a given param * @@ -213,12 +214,71 @@ object TypedExpr { extends TypedExpr[T] { def tag: T = term.tag } - case class AnnotatedLambda[T]( + case class AnnotatedLambda[+T]( args: NonEmptyList[(Bindable, Type)], expr: TypedExpr[T], tag: T - ) extends TypedExpr[T] - case class Local[T](name: Bindable, tpe: Type, tag: T) extends Name[T] + ) extends TypedExpr[T] { + + // This makes sure the args don't shadow any of the items in freeSet + def unshadow(freeSet: Set[Bindable]): AnnotatedLambda[T] = { + val clashIdent = + if (freeSet.isEmpty) Set.empty[Bindable] + else args.iterator.flatMap { + case (n, _) if freeSet(n) => n :: Nil + case _ => Nil + }.toSet + + if (clashIdent.isEmpty) this + else { + // we have to allocate new variables + type I = Bindable + def inc(n: I, idx: Int): I = + n match { + case Identifier.Name(n) => Identifier.Name(n + idx.toString) + case _ => Identifier.Name("a" + idx.toString) + } + + def alloc(head: (I, Type), tail: List[(I, Type)], avoid: Set[I]): NonEmptyList[(I, Type)] = { + val (ident, tpe) = head + val ident1 = + if (clashIdent(ident)) { + // the following iterator is infinite and distinct, and the avoid + // set is finite, so the get here must terminate in at most avoid.size + // steps + Iterator.from(0) + .map { i => inc(ident, i) } + .collectFirst { case n if !avoid(n) => n } + .get + + } + else ident + + tail match { + case Nil => NonEmptyList.one((ident1, tpe)) + case h :: t => + (ident1, tpe) :: alloc(h, t, avoid + ident1) + } + } + + val avoids = freeSet | freeVarsSet(expr :: Nil) + val newArgs = alloc(args.head, args.tail, avoids) + val resSub = args.iterator.map(_._1) + .zip(newArgs.iterator.map { case (n1, _) => + + { (loc: Local[T]) => Local(n1, loc.tpe, loc.tag) } + }) + .toMap + + // calling .get is safe when enterLambda = true + val expr1 = substituteAll(resSub, expr, enterLambda = true).get + + AnnotatedLambda(newArgs, expr1, tag) + } + } + } + + case class Local[+T](name: Bindable, tpe: Type, tag: T) extends Name[T] case class Global[T](pack: PackageName, name: Identifier, tpe: Type, tag: T) extends Name[T] case class App[T]( @@ -233,7 +293,72 @@ object TypedExpr { in: TypedExpr[T], recursive: RecursionKind, tag: T - ) extends TypedExpr[T] + ) extends TypedExpr[T] { + def unshadowResult(freeSet: Set[Bindable]): Let[T] = { + val clashIdent = + if (freeSet(arg)) Set(arg) + else Set.empty + + if (clashIdent.isEmpty) this + else { + // we have to allocate new let + val avoids = freeSet | freeVarsSet(in :: expr :: Nil) + type I = Bindable + def inc(n: I, idx: Int): I = + n match { + case Identifier.Name(n) => Identifier.Name(n + idx.toString) + case _ => Identifier.Name("a" + idx.toString) + } + + val arg1 = Iterator.from(0) + .map { i => inc(arg, i) } + .collectFirst { case n if !avoids(n) => n } + .get + + val resSub = Map(arg -> + { (loc: Local[T]) => Local(arg1, loc.tpe, loc.tag) } + ) + + // calling .get is safe when enterLambda = true + val in1 = substituteAll(resSub, in, enterLambda = true).get + + copy(arg = arg1, in = in1) + } + } + + def unshadowBoth(freeSet: Set[Bindable]): Let[T] = { + val clashIdent = + if (freeSet(arg)) Set(arg) + else Set.empty + + if (clashIdent.isEmpty) this + else { + // we have to allocate new let + val avoids = freeSet | freeVarsSet(in :: expr :: Nil) + type I = Bindable + def inc(n: I, idx: Int): I = + n match { + case Identifier.Name(n) => Identifier.Name(n + idx.toString) + case _ => Identifier.Name("a" + idx.toString) + } + + val arg1 = Iterator.from(0) + .map { i => inc(arg, i) } + .collectFirst { case n if !avoids(n) => n } + .get + + val resSub = Map(arg -> + { (loc: Local[T]) => Local(arg1, loc.tpe, loc.tag) } + ) + + // calling .get is safe when enterLambda = true + val expr1 = substituteAll(resSub, expr, enterLambda = true).get + val in1 = substituteAll(resSub, in, enterLambda = true).get + + copy(arg = arg1, expr = expr1, in = in1) + } + } + } // TODO, this shouldn't have a type, we know the type from Lit currently case class Literal[T](lit: Lit, tpe: Type, tag: T) extends TypedExpr[T] case class Match[T]( @@ -637,16 +762,13 @@ object TypedExpr { * which has a type forall a. Int which is the same * as Int */ - type Branch = - (Pattern[(PackageName, Constructor), Type], TypedExpr[A]) - val allMatchMetas: F[SortedSet[Type.Meta]] = getMetaTyVars(arg.getType :: branches.foldMap { case (p, _) => allPatternTypes(p) }.toList) val env1 = env + te.getType - def handleBranch(br: Branch): F[Branch] = { + def handleBranch(br: Branch[A]): F[Branch[A]] = { val (p, expr) = br val branchEnv = env1 ++ Pattern .envOf(p, Map.empty)(ident => (None, ident)) @@ -1146,54 +1268,207 @@ object TypedExpr { ex: TypedExpr[A], in: TypedExpr[A], enterLambda: Boolean = true - ): Option[TypedExpr[A]] = { - // if we hit a shadow, we don't need to substitute down - // that branch - @inline def shadows(i: Bindable): Boolean = i === ident - - // free variables in ex are being rebound, - // this causes us to return None - lazy val masks: Bindable => Boolean = - freeVarsSet(ex :: Nil) + ): Option[TypedExpr[A]] = + substituteAll(Map(ident -> { (_: Local[A]) => ex }), in, enterLambda) - def loop(in: TypedExpr[A]): Option[TypedExpr[A]] = + // Invariant, if enterLambda == true, we always return Some + def substituteAll[A]( + table: Map[Bindable, Local[A] => TypedExpr[A]], + in: TypedExpr[A], + enterLambda: Boolean = true + ): Option[TypedExpr[A]] = { + def loop(table: Map[Bindable, Local[A] => TypedExpr[A]], in: TypedExpr[A]): Option[TypedExpr[A]] = in match { - case Local(i, _, _) if i === ident => Some(ex) - case Global(_, _, _, _) | Local(_, _, _) | Literal(_, _, _) => Some(in) + case local @ Local(i, _, _) => + table.get(i) match { + case Some(te) => Some(te(local)) + case None => Some(in) + } + case Global(_, _, _, _) | Literal(_, _, _) => Some(in) case Generic(a, expr) => - loop(expr).map(Generic(a, _)) + loop(table, expr).map(Generic(a, _)) case Annotation(t, tpe) => - loop(t).map(Annotation(_, tpe)) - case AnnotatedLambda(args, res, tag) => + loop(table, t).map(Annotation(_, tpe)) + case lam @ AnnotatedLambda(args, res, tag) => if (!enterLambda) None - else if (args.exists { case (n, _) => masks(n) }) None - else if (args.exists { case (n, _) => shadows(n) }) Some(in) - else loop(res).map(AnnotatedLambda(args, _, tag)) - case App(fn, args, tpe, tag) => - (loop(fn), args.traverse(loop(_))).mapN(App(_, _, tpe, tag)) - case let @ Let(arg, argE, in, rec, tag) => - if (masks(arg)) None - else if (shadows(arg)) { - // recursive shadow blocks both argE and in - if (rec.isRecursive) Some(let) - else loop(argE).map(Let(arg, _, in, rec, tag)) - } else { - (loop(argE), loop(in)).mapN(Let(arg, _, _, rec, tag)) + else { + // this is the same algorithm as python/Code.Expression.substitute + // + // the args here can shadow, so we have to remove any + // items from subMap that have the same Ident + val argsSet = args.iterator.map(_._1).toSet + val nonShadowed = table.filterNot { case (i, _) => argsSet(i) } + // if subFrees is empty, unshadow is a no-op. + // but that is efficiently handled by unshadow + val subFrees = nonShadowed.iterator + .map { case (n, v) => + // TODO this isn't great but we just need to get the free vars from the function + // this assumes the replacement free variables is constant over the type + // which it should be otherwise we can make ill-typed TypedExpr + val dummyTpe = res.getType + freeVarsSet(v(Local(n, tpe = dummyTpe, tag)) :: Nil) + } + .foldLeft(nonShadowed.keySet)(_ | _) + + val AnnotatedLambda(args1, res1, tag1) = lam.unshadow(subFrees) + // now we know that none of args1 shadow anything in subFrees + // so we can just directly substitute nonShadowed on res1 + // put another way: unshadow make substitute "commute" with lambda. + val subRes = substituteAll(nonShadowed, res1, enterLambda = true).get + Some(AnnotatedLambda(args1, subRes , tag1)) } + case App(fn, args, tpe, tag) => + (loop(table, fn), args.traverse(loop(table, _))).mapN(App(_, _, tpe, tag)) + case let @ Let(arg, _, _, _, _) => + if (let.recursive.isRecursive) { + // arg is in scope for argE and in + // the args here can shadow, so we have to remove any + // items from subMap that have the same Ident + val nonShadowed = table - arg + // if subFrees is empty, unshadow is a no-op. + // but that is efficiently handled by unshadow + val subFrees = nonShadowed.iterator + .map { case (n, v) => + // TODO this isn't great but we just need to get the free vars from the function + // this assumes the replacement free variables is constant over the type + // which it should be otherwise we can make ill-typed TypedExpr + val dummyTpe = in.getType + freeVarsSet(v(Local(n, tpe = dummyTpe, let.tag)) :: Nil) + } + .foldLeft(nonShadowed.keySet)(_ | _) + + val Let(arg1, argE1, in1, rec1, tag1) = let.unshadowBoth(subFrees) + // now we know that none of args1 shadow anything in subFrees + // so we can just directly substitute nonShadowed on res1 + // put another way: unshadow make substitute "commute" with lambda. + (substituteAll(nonShadowed, argE1, enterLambda), substituteAll(nonShadowed, in1, enterLambda)) + .mapN(Let(arg1, _, _, rec1, tag1)) + } + else { + // the scopes are different the binding and the result + // the args here can shadow, so we have to remove any + // items from subMap that have the same Ident + val argsSet = Set(arg) + val nonShadowed = table.filterNot { case (i, _) => argsSet(i) } + // if subFrees is empty, unshadow is a no-op. + // but that is efficiently handled by unshadow + val subFrees = nonShadowed.iterator + .map { case (n, v) => + // TODO this isn't great but we just need to get the free vars from the function + // this assumes the replacement free variables is constant over the type + // which it should be otherwise we can make ill-typed TypedExpr + val dummyTpe = in.getType + freeVarsSet(v(Local(n, tpe = dummyTpe, let.tag)) :: Nil) + } + .foldLeft(nonShadowed.keySet)(_ | _) + + val Let(arg1, argE1, in1, rec1, tag1) = let.unshadowResult(subFrees) + // now we know that none of args1 shadow anything in subFrees + // so we can just directly substitute nonShadowed on res1 + // put another way: unshadow make substitute "commute" with lambda. + (loop(table, argE1), loop(nonShadowed, in1)) + .mapN(Let(arg1, _, _, rec1, tag1)) + } case Match(arg, branches, tag) => // Maintain the order we encounter things: - val arg1 = loop(arg) + val arg1 = loop(table, arg) val b1 = branches.traverse { case in @ (p, b) => // these are not free variables in this branch val ns = p.names - if (ns.exists(masks)) None - else if (ns.exists(shadows)) Some(in) - else loop(b).map((p, _)) + + val (table1, (p1, b1)) = + if (ns.isEmpty) (table, in) + else { + // the args here can shadow, so we have to remove any + // items from subMap that have the same Ident + val argsSet = ns.toSet + val nonShadowed = + if (argsSet.isEmpty) table + else table.filterNot { case (i, _) => argsSet(i) } + + // if subFrees is empty, unshadow is a no-op. + // but that is efficiently handled by unshadow + val subFrees = nonShadowed.iterator + .map { case (n, v) => + // TODO this isn't great but we just need to get the free vars from the function + // this assumes the replacement free variables is constant over the type + // which it should be otherwise we can make ill-typed TypedExpr + val dummyTpe = b.getType + freeVarsSet(v(Local(n, tpe = dummyTpe, tag)) :: Nil) + } + .foldLeft(nonShadowed.keySet)(_ | _) + + (nonShadowed, unshadowBranch[A](subFrees, in)) + } + // now we know that none of args1 shadow anything in subFrees + // so we can just directly substitute nonShadowed on res1 + // put another way: unshadow make substitute "commute" with lambda. + loop(table1, b1).map((p1, _)) } (arg1, b1).mapN(Match(_, _, tag)) } - loop(in) + loop(table, in) + } + + private def unshadowBranch[A](freeSet: Set[Bindable], branch: Branch[A]): Branch[A] = { + // we only get in here when p has some names + val (p, b) = branch + val args = NonEmptyList.fromList(p.names) match { + case None => return branch + case Some(argsNel) => argsNel + } + + val clashIdent = + if (freeSet.isEmpty) Set.empty[Bindable] + else args.iterator.filter(freeSet).toSet + + if (clashIdent.isEmpty) branch + else { + // we have to allocate new variables + type I = Bindable + def inc(n: I, idx: Int): I = + n match { + case Identifier.Name(n) => Identifier.Name(n + idx.toString) + case _ => Identifier.Name("a" + idx.toString) + } + + def alloc(ident: I, tail: List[I], avoid: Set[I]): NonEmptyList[I] = { + val ident1 = + if (clashIdent(ident)) { + // the following iterator is infinite and distinct, and the avoid + // set is finite, so the get here must terminate in at most avoid.size + // steps + Iterator.from(0) + .map { i => inc(ident, i) } + .collectFirst { case n if !avoid(n) => n } + .get + + } + else ident + + tail match { + case Nil => NonEmptyList.one(ident1) + case h :: t => + ident1 :: alloc(h, t, avoid + ident1) + } + } + + val avoids = freeSet | freeVarsSet(b :: Nil) + val newArgs = alloc(args.head, args.tail, avoids) + val resSub = args.iterator + .zip(newArgs.iterator.map { n1 => + + { (loc: Local[A]) => Local(n1, loc.tpe, loc.tag) } + }) + .toMap + + // calling .get is safe when enterLambda = true + val b1 = substituteAll(resSub, b, enterLambda = true).get + val p1 = p.substitute(args.iterator.zip(newArgs.iterator).toMap) + + (p1, b1) + } } def substituteTypeVar[A]( @@ -1466,6 +1741,9 @@ object TypedExpr { } } + type Branch[A] = + (Pattern[(PackageName, Constructor), Type], TypedExpr[A]) + def quantVars[A]( forallList: List[(Type.Var.Bound, Kind)], existList: List[(Type.Var.Bound, Kind)], diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index 18f7f1001..4875c711e 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -119,6 +119,18 @@ object TypedExprNormalization { private def setType[A](expr: TypedExpr[A], tpe: Type): TypedExpr[A] = if (!tpe.sameAs(expr.getType)) Annotation(expr, tpe) else expr + private def appLambda[A](f1: AnnotatedLambda[A], args: NonEmptyList[TypedExpr[A]], tpe: Type, tag: A): TypedExpr[A] = { + val freesInArgs = TypedExpr.freeVarsSet(args.toList) + val AnnotatedLambda(lamArgs, expr, _) = f1.unshadow(freesInArgs) + // Now that we certainly don't shadow we can convert this: + // ((y1, y2, ..., yn) -> z)(x1, x2, ..., xn) = let y1 = x1 in let y2 = x2 in ... z + val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => + (n, setType(arg, ltpe)) + } + val expr2 = setType(expr, tpe) + TypedExpr.letAllNonRec(lets, expr2, tag) + } + /** if the te is not in normal form, transform it into normal form */ private def normalizeLetOpt[A, V]( @@ -208,58 +220,24 @@ object TypedExprNormalization { case _ => false } - val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv)) e1 match { case App(fn, aargs, _, _) if matchesArgs(aargs) && doesntUseArgs(fn) => // x -> f(x) == f (eta conversion) - normalize1(None, setType(fn, te.getType), scope, typeEnv) - case App( - ws.ResolveToLambda(Nil, args1, body, ftag), - aargs, - resT, - atag - ) if namerec.isEmpty => - // args -> (args1 -> e1)(...) - // this is inlining, which we do only when nested directly inside another lambda - // TODO: this is possibly very expensive to always apply. It can really increase - // code size. We probably need better hueristics for when to inline, - // or remove inlining from here unless it can hever hurt and put inlining at a - // different phase. - val fn1 = AnnotatedLambda(args1, body, ftag) - val e2 = App(fn1, aargs, resT, atag) - if (e1 != e2) { - // in this case we have inlined, vs there already being - // a literal lambda being applied - // by normalizing this, it will become a let binding - val e3 = normalize1(None, e2, bodyScope, typeEnv).get - - if (e3.size <= expr.size) { - // we haven't made the code larger - normalize1( - namerec, - AnnotatedLambda(lamArgs, e3, tag), - scope, - typeEnv - ) - } else { - // inlining will make the code larger that it was originally - if ((e1 eq expr) && (lamArgs === lamArgs0)) None - else Some(AnnotatedLambda(lamArgs, e1, tag)) - } - } else { - if ((e1 eq expr) && (lamArgs === lamArgs0)) None - else Some(AnnotatedLambda(lamArgs, e1, tag)) - } + // note, e1 is already normalized, so fn is normalized + Some(setType(fn, te.getType)) case Let(arg1, ex, in, rec, tag1) - if doesntUseArgs(ex) && doesntShadow(arg1) => + if !Impl.isSimple(ex, lambdaSimple = true) && doesntUseArgs(ex) && doesntShadow(arg1)=> // x -> // y = z // f(y) // same as: // y = z // x -> f(y) - // avoid recomputing y + // avoid recomputing y if y is not simple. Note, we consider a lambda simple + // since when compiling we can lift lambdas out anyway, so they are at most 1 allocation + // but possibly 0. + // // TODO: we could reorder Lets if we have several in a row normalize1( None, @@ -269,8 +247,10 @@ object TypedExprNormalization { ) case m @ Match(arg1, branches, tag1) if lamArgs.forall { case (arg, _) => arg1.notFree(arg) - } => - // same as above: if match does not depend on lambda arg, lift it out + } && ((branches.length > 1) || !Impl.isSimple(arg1, lambdaSimple = true)) => + // x -> match z: w + // convert to match z: x -> w + // but don't bother if the arg is simple or there is only 1 branch + simple arg val b1 = branches.traverse { case (p, b) => if ( !lamArgs.exists { case (arg, _) => p.names.contains(arg) } @@ -326,24 +306,24 @@ object TypedExprNormalization { lazy val a1 = ListUtil.mapConserveNel(args) { a => normalize1(None, a, scope, typeEnv).get } + val ws = Impl.WithScope(scope, ev.substituteCo[TypeEnv](typeEnv)) f1 match { // TODO: what if f1: Generic(_, AnnotatedLambda(_, _, _)) // we should still be able ton convert this to a let by // instantiating to the right args - case AnnotatedLambda(lamArgs, expr, _) => - // (y -> z)(x) = let y = x in z - val lets = lamArgs.zip(args).map { case ((n, ltpe), arg) => - (n, setType(arg, ltpe)) - } - val expr2 = setType(expr, tpe) - val l = TypedExpr.letAllNonRec(lets, expr2, tag) + case ws.ResolveToLambda(Nil, args1, body, ftag) => + val lam = AnnotatedLambda(args1, body, ftag) + val l = appLambda[A](lam, args, tpe, tag) + normalize1(namerec, l, scope, typeEnv) + case lam @ AnnotatedLambda(_, _, _) => + val l = appLambda[A](lam, args, tpe, tag) normalize1(namerec, l, scope, typeEnv) case Let(arg1, ex, in, rec, tag1) if a1.forall(_.notFree(arg1)) => // (app (let x y z) w) == (let x y (app z w)) if w does not have x free normalize1( namerec, - Let(arg1, ex, App(in, a1, tpe, tag), rec, tag1), + Let(arg1, ex, App(in, args, tpe, tag), rec, tag1), scope, typeEnv ) @@ -540,6 +520,9 @@ object TypedExprNormalization { } object ResolveToLambda { + // this is a parameter that we can tune to change inlining Global Lambdas + val MaxSize = 10 + // TODO: don't we need to worry about the type environment for locals? They // can also capture type references to outer Generics def unapply(te: TypedExpr[A]): Option[ @@ -577,7 +560,9 @@ object TypedExprNormalization { Some((Nil, args, expr, ltag)) case Global(p, n: Bindable, _, _) => scope.getGlobal(p, n).flatMap { - case (RecursionKind.NonRecursive, te, scope1) => + // only inline global lambdas if they are somewhat small, otherwise we will + // tend to transitively inline everything into one big function and blow the stack + case (RecursionKind.NonRecursive, te, scope1) if te.size < MaxSize => val s1 = WithScope(scope1, typeEnv) te match { case s1.ResolveToLambda(frees, args, expr, ltag) => @@ -599,6 +584,7 @@ object TypedExprNormalization { } case Local(nm, _, _) => scope.getLocal(nm).flatMap { + // Local lambdas tend to be small, so inline them always if we can case (RecursionKind.NonRecursive, te, scope1) => val s1 = WithScope(scope1, typeEnv) te match { @@ -624,17 +610,30 @@ object TypedExprNormalization { } } + final def isSimpleNotTail[A](ex: TypedExpr[A], lambdaSimple: Boolean): Boolean = + isSimple(ex, lambdaSimple) + @annotation.tailrec final def isSimple[A](ex: TypedExpr[A], lambdaSimple: Boolean): Boolean = ex match { case Literal(_, _, _) | Local(_, _, _) | Global(_, _, _, _) => true + case App(_, _, _, _) => false case Annotation(t, _) => isSimple(t, lambdaSimple) case Generic(_, t) => isSimple(t, lambdaSimple) case AnnotatedLambda(_, _, _) => // maybe inline lambdas so we can possibly // apply (x -> f)(g) => let x = g in f lambdaSimple - case _ => false + case Let(_, ex, in, _, _) => + isSimpleNotTail(ex, lambdaSimple) && isSimple(in, lambdaSimple) + case Match(arg, branches, _) => + branches.tail.isEmpty && isSimpleNotTail(arg, lambdaSimple) && { + // match f: case p: r + // is the same as + // let p = f in r + val (_, rest) = branches.head + isSimple(rest, lambdaSimple) + } } sealed abstract class EvalResult[A] diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala index 219381a90..f1e7e6e99 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala @@ -233,6 +233,7 @@ object ClangGen { def directFn(p: PackageName, b: Bindable): T[Option[(Code.Ident, Int)]] def directFn(b: Bindable): T[Option[(Code.Ident, Boolean, Int)]] def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A] + def inFnStatement[A](in: T[A]): T[A] def currentTop: T[Option[(PackageName, Bindable)]] def staticValueName(p: PackageName, b: Bindable): T[Code.Ident] def constructorFn(p: PackageName, b: Bindable): T[Code.Ident] @@ -1026,7 +1027,7 @@ object ClangGen { } def fnStatement(fnName: Code.Ident, fn: FnExpr): T[Code.Statement] = - fn match { + inFnStatement(fn match { case Lambda(captures, name, args, expr) => val body = innerToValue(expr).map(Code.returnValue(_)) val body1 = name match { @@ -1047,7 +1048,7 @@ object ClangGen { } } yield Code.DeclareFn(Nil, Code.TypeIdent.BValue, fnName, allArgs.toList, Some(Code.block(fnBody))) } - } + }) def renderTop(p: PackageName, b: Bindable, expr: Expr): T[Unit] = inTop(p, b) { expr match { @@ -1112,6 +1113,31 @@ object ClangGen { def catsMonad[S]: Monad[StateT[EitherT[Eval, Error, *], S, *]] = implicitly new Env { + sealed abstract class BindingKind { + def ident: Code.Ident + } + object BindingKind { + case class Normal(bn: Bindable, idx: Int) extends BindingKind { + val ident = Code.Ident(Idents.escape("__bsts_b_", bn.asString + idx.toString)) + } + case class Recursive(ident: Code.Ident, isClosure: Boolean, arity: Int, idx: Int) extends BindingKind + } + case class BindState(count: Int, stack: List[BindingKind]) { + def pop: BindState = + // by invariant this tail should never fail + copy(stack = stack.tail) + + def nextBind(bn: Bindable): BindState = + copy(count = count + 1, BindingKind.Normal(bn, count) :: stack) + + def nextRecursive(fnName: Code.Ident, isClosure: Boolean, arity: Int): BindState = + copy(count = count + 1, BindingKind.Recursive(fnName, isClosure, arity, count) :: stack) + } + + object BindState { + val empty: BindState = BindState(0, Nil) + } + case class State( allValues: AllValues, externals: ExternalResolver, @@ -1119,7 +1145,7 @@ object ClangGen { includes: Chain[Code.Include], stmts: Chain[Code.Statement], currentTop: Option[(PackageName, Bindable)], - binds: Map[Bindable, NonEmptyList[Either[((Code.Ident, Boolean, Int), Int), Int]]], + binds: Map[Bindable, BindState], counter: Long, identCache: Map[Expr, Code.Ident] ) { @@ -1168,46 +1194,50 @@ object ClangGen { Eval.now(Right((s, a))) ) + def update[A](fn: State => (State, A)): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(Right(fn(s))))) + + def tryUpdate[A](fn: State => Either[Error, (State, A)]): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(fn(s)))) + + def read[A](fn: State => A): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(Right((s, fn(s)))))) + + def tryRead[A](fn: State => Either[Error, A]): T[A] = + StateT(s => EitherT[Eval, Error, (State, A)](Eval.now(fn(s).map((s, _))))) + def globalIdent(pn: PackageName, bn: Bindable): T[Code.Ident] = - StateT { s => + tryUpdate { s => s.externals(pn, bn) match { case Some((incl, ident, _)) => // TODO: suspect that we are ignoring arity here val withIncl = s.include(incl) - result(withIncl, ident) + Right((withIncl, ident)) case None => val key = (pn, bn) s.allValues.get(key) match { - case Some((_, ident)) => result(s, ident) - case None => errorRes(Error.UnknownValue(pn, bn)) + case Some((_, ident)) => Right((s, ident)) + case None => Left(Error.UnknownValue(pn, bn)) } } } def bind[A](bn: Bindable)(in: T[A]): T[A] = { - val init: T[Unit] = StateT { s => - val v = s.binds.get(bn) match { - case None => NonEmptyList.one(Right(0)) - case Some(items @ NonEmptyList(Right(idx), _)) => - Right(idx + 1) :: items - case Some(items @ NonEmptyList(Left((_, idx)), _)) => - Right(idx + 1) :: items + val init: T[Unit] = update { s => + val bs0 = s.binds.get(bn) match { + case None => BindState.empty + case Some(bs) => bs } - result(s.copy(binds = s.binds.updated(bn, v)), ()) + val bs1 = bs0.nextBind(bn) + (s.copy(binds = s.binds.updated(bn, bs1)), ()) } - val uninit: T[Unit] = StateT { s => - s.binds.get(bn) match { - case Some(NonEmptyList(_, tail)) => - val s1 = NonEmptyList.fromList(tail) match { - case None => - s.copy(binds = s.binds - bn) - case Some(prior) => - s.copy(binds = s.binds.updated(bn, prior)) - } - result(s1, ()) + val uninit: T[Unit] = update { s => + val bs1 = s.binds.get(bn) match { + case Some(bs) => bs.pop case None => sys.error(s"bindable $bn no longer in $s") } + (s.copy(binds = s.binds.updated(bn, bs1)), ()) } for { @@ -1217,18 +1247,10 @@ object ClangGen { } yield a } def getBinding(bn: Bindable): T[Code.Ident] = - StateT { s => + tryRead { s => s.binds.get(bn) match { - case Some(stack) => - stack.head match { - case Right(idx) => - result(s, Code.Ident(Idents.escape("__bsts_b_", bn.asString + idx.toString))) - case Left(((ident, _, _), _)) => - // TODO: suspicious to ignore isClosure and arity here - // probably need to conv - result(s, ident) - } - case None => errorRes(Error.Unbound(bn, s.currentTop)) + case Some(bs) => Right(bs.stack.head.ident) + case None => Left(Error.Unbound(bn, s.currentTop)) } } def bindAnon[A](idx: Long)(in: T[A]): T[A] = @@ -1241,30 +1263,21 @@ object ClangGen { // a recursive function needs to remap the Bindable to the top-level mangling def recursiveName[A](fnName: Code.Ident, bn: Bindable, isClosure: Boolean, arity: Int)(in: T[A]): T[A] = { - val init: T[Unit] = StateT { s => - val entry = (fnName, isClosure, arity) - val v = s.binds.get(bn) match { - case None => NonEmptyList.one(Left((entry, -1))) - case Some(items @ NonEmptyList(Right(idx), _)) => - Left((entry, idx)) :: items - case Some(items @ NonEmptyList(Left((_, idx)), _)) => - Left((entry, idx)) :: items - } - result(s.copy(binds = s.binds.updated(bn, v)), ()) + val init: T[Unit] = update { s => + val bs0 = s.binds.get(bn) match { + case Some(bs) => bs + case None => BindState.empty + } + val bs1 = bs0.nextRecursive(fnName, isClosure, arity) + (s.copy(binds = s.binds.updated(bn, bs1)), ()) } - val uninit: T[Unit] = StateT { s => - s.binds.get(bn) match { - case Some(NonEmptyList(_, tail)) => - val s1 = NonEmptyList.fromList(tail) match { - case None => - s.copy(binds = s.binds - bn) - case Some(prior) => - s.copy(binds = s.binds.updated(bn, prior)) - } - result(s1, ()) + val uninit: T[Unit] = update { s => + val bs1 = s.binds.get(bn) match { + case Some(bs) => bs.pop case None => sys.error(s"bindable $bn no longer in $s") } + (s.copy(binds = s.binds.updated(bn, bs1)), ()) } for { @@ -1310,20 +1323,28 @@ object ClangGen { } def directFn(b: Bindable): T[Option[(Code.Ident, Boolean, Int)]] = - StateT { s => + read { s => s.binds.get(b) match { - case Some(NonEmptyList(Left((c, _)), _)) => - result(s, Some(c)) + case Some(BindState(_, BindingKind.Recursive(n, c, a, _) :: _)) => + Some((n, c, a)) case _ => - result(s, None) + None } } + def inFnStatement[A](ta: T[A]): T[A] = { + for { + bindState <- update { (s: State) => (s.copy(binds = Map.empty), s.binds) } + a <- ta + _ <- update { (s: State) => (s.copy(binds = bindState), ()) } + } yield a + } + def inTop[A](p: PackageName, bn: Bindable)(ta: T[A]): T[A] = for { - _ <- StateT { (s: State) => result(s.copy(currentTop = Some((p, bn))), ())} + bindState <- update { (s: State) => (s.copy(binds = Map.empty, currentTop = Some((p, bn))), s.binds)} a <- ta - _ <- StateT { (s: State) => result(s.copy(currentTop = None), ()) } + _ <- update { (s: State) => (s.copy(currentTop = None, binds = bindState), ()) } } yield a val currentTop: T[Option[(PackageName, Bindable)]] = diff --git a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala index 0d1673db8..a42513d2c 100644 --- a/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala +++ b/core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala @@ -47,41 +47,53 @@ object PythonGen { private object Impl { + case class BindState(binding: Bindable, count: Int, stack: List[Code.Ident]) { + def currentOption: Option[Code.Ident] = stack.headOption + + def current: Code.Ident = + stack match { + case h :: _ => h + case Nil => sys.error(s"invariant violation: $binding, count = $count has no bindings.") + } + + def next: (BindState, Code.Ident) = { + val pname = Code.Ident(Idents.escape("___b", binding.asString + count.toString)) + (copy(count = count + 1, stack = pname :: stack), pname) + } + + def pop: BindState = + stack match { + case _ :: tail => copy(stack = tail) + case Nil => sys.error(s"invariant violation: $binding, count = $count has no bindings to pop") + } + } + object BindState { + def empty(b: Bindable): BindState = BindState(b, 0, Nil) + } + case class EnvState( imports: Map[Module, Code.Ident], - bindings: Map[Bindable, (Int, List[Code.Ident])], + bindings: Map[Bindable, BindState], tops: Set[Bindable], nextTmp: Long ) { - private def bindInc(b: Bindable, inc: Int)( - fn: Int => Code.Ident - ): (EnvState, Code.Ident) = { - val (c, s) = bindings.getOrElse(b, (0, Nil)) - val pname = fn(c) + def bind(b: Bindable): (EnvState, Code.Ident) = { + val bs = bindings.getOrElse(b, BindState.empty(b)) + val (bs1, pname) = bs.next ( copy( - bindings = bindings.updated(b, (c + inc, pname :: s)) + bindings = bindings.updated(b, bs1) ), pname ) } - def bind(b: Bindable): (EnvState, Code.Ident) = - bindInc(b, 1) { c => - Code.Ident(Idents.escape("___b", b.asString + c.toString)) - } - - // in loops we need to substitute - // bindings for mutable variables - def subs(b: Bindable, c: Code.Ident): EnvState = - bindInc(b, 0)(_ => c)._1 - def deref(b: Bindable): Code.Ident = // see if we are shadowing, or top level - bindings.get(b) match { - case Some((_, h :: _)) => h + bindings.get(b).flatMap(_.currentOption) match { + case Some(h) => h case _ if tops(b) => escape(b) case other => // $COVERAGE-OFF$ @@ -93,8 +105,8 @@ object PythonGen { def unbind(b: Bindable): EnvState = bindings.get(b) match { - case Some((cnt, _ :: tail)) => - copy(bindings = bindings.updated(b, (cnt, tail))) + case Some(bs) => + copy(bindings = bindings.updated(b, bs.pop)) case other => // $COVERAGE-OFF$ throw new IllegalStateException( @@ -167,10 +179,6 @@ object PythonGen { def bind(b: Bindable): Env[Code.Ident] = Impl.env(_.bind(b)) - // point this name to the top level name - def subs(b: Bindable, i: Code.Ident): Env[Unit] = - Impl.update(_.subs(b, i)) - // get the mapping for a name in scope def deref(b: Bindable): Env[Code.Ident] = Impl.read(_.deref(b)) @@ -1567,11 +1575,8 @@ object PythonGen { ): Env[Statement] = expr match { case Lambda(captures, _, args, body) => - // we can ignore name because python already allows recursion - // we can use topLevelName on makeDefs since they are already - // shadowing in the same rules as bosatsu ( - args.traverse(Env.topLevelName(_)), + args.traverse(Env.bind(_)), makeSlots(captures, slotName)(loop(body, _)) ) .mapN { case (as, (slots, body)) => @@ -1580,7 +1585,7 @@ object PythonGen { Env.makeDef(name, as, body) :: Nil ) - } + } <* args.traverse_(Env.unbind(_)) } def makeSlots[A](captures: List[Expr], slotName: Option[Code.Ident])( @@ -1621,7 +1626,7 @@ object PythonGen { case Some(n) => Env.bind(n) } ( - args.traverse(Env.topLevelName(_)), + args.traverse(Env.bind(_)), defName, makeSlots(captures, slotName)(loop(res, _)) ) @@ -1634,7 +1639,8 @@ object PythonGen { defn = Env.makeDef(defName, args, v) block = Code.blockFromList(prefix.toList ::: defn :: Nil) } yield block.withValue(defName) - } + } <* args.traverse_(Env.unbind(_)) + case WhileExpr(cond, effect, res) => (boolExpr(cond, slotName), loop(effect, slotName), loop(res, slotName), Env.newAssignableVar) .mapN { (cond, effect, res, c) => diff --git a/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala b/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala index bac97a19e..577679e11 100644 --- a/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/PatternTest.scala @@ -10,7 +10,7 @@ import org.scalatest.funsuite.AnyFunSuite class PatternTest extends AnyFunSuite { implicit val generatorDrivenConfig: PropertyCheckConfiguration = - PropertyCheckConfiguration(minSuccessful = 1000) + PropertyCheckConfiguration(minSuccessful = 5000) val patGen = Gen.choose(0, 5).flatMap(Generators.genPattern(_)) @@ -152,4 +152,35 @@ class PatternTest extends AnyFunSuite { } } } + + test("substitute identity is identity") { + forAll(patGen, Gen.listOf(Generators.bindIdentGen)) { (p, list) => + assert(p.substitute(list.map(b => (b, b)).toMap) == p) + } + } + + test("substitute names homomorphism") { + import Identifier._ + + def law[A, B](p: Pattern[A, B], map: Map[Bindable, Bindable]) = { + val subsP = p.substitute(map) + assert(subsP.names.distinct == p.names.map(n => map.getOrElse(n, n)).distinct, s"got $subsP") + } + + def b(s: String) = Identifier.Name(s) + + { + import Pattern._ + import StrPart._ + import Lit.Str + + val p = Union(Var(Name("a")), NonEmptyList(StrPat(NonEmptyList(NamedStr(Name("k")), List(LitStr("wrk"), WildChar))), List(Named(Name("hqZ9aeuAood"), WildCard), Literal(Str("q5VgEdksu")), WildCard))) + + law(p, Map(b("k") -> b("a"))) + } + + forAll(patGen, Gen.mapOf(Gen.zip(Generators.bindIdentGen, Generators.bindIdentGen))) { (p, map) => + law(p, map) + } + } } diff --git a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala index ca7b8428f..1c82e0cbe 100644 --- a/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/TypedExprTest.scala @@ -416,8 +416,8 @@ foo = _ -> 1 // substitution in let with a masking val let1 = let("y", varTE("x", intTpe), varTE("y", intTpe)) assert( - TypedExpr.substitute(Identifier.Name("x"), varTE("y", intTpe), let1) == - None + TypedExpr.substitute(Identifier.Name("x"), varTE("y", intTpe), let1).map(_.reprString) == + Some("(let y0 (var y Bosatsu/Predef::Int) (var y0 Bosatsu/Predef::Int))") ) } @@ -516,37 +516,46 @@ foo = _ -> 1 ) ) - test("we can't inline using a shadow: let x = y in let y = z in x(y, y)") { - // we can't inline a shadow + test("we can inline using a shadow: let x = y in let y = z(43) in x(y)(y)") { + // we can inline a shadow by unshadowing y to be y1 // x = y // y = z(43) // x(y, y) - assert(TypedExprNormalization.normalize(normalLet) == None) + val normed = TypedExprNormalization.normalize(normalLet) + assert(normed.map(_.repr.render(80)) == Some("""(let + y0 + (ap + (var z Bosatsu/Predef::Int) + (lit 43 Bosatsu/Predef::Int) + Bosatsu/Predef::Int) + (ap + (ap + (var y Bosatsu/Predef::Int) + (var y0 Bosatsu/Predef::Int) + Bosatsu/Predef::Int) + (var y0 Bosatsu/Predef::Int) + Bosatsu/Predef::Int))""")) } test("if w doesn't have x free: (app (let x y z) w) == let x y (app z w)") { assert( TypedExprNormalization.normalize( app(normalLet, varTE("w", intTpe), intTpe) - ) == + ).map(_.reprString) == Some( - let( - "x", - varTE("y", intTpe), let( - "y", + "y0", app(varTE("z", intTpe), int(43), intTpe), app( app( - app(varTE("x", intTpe), varTE("y", intTpe), intTpe), - varTE("y", intTpe), + app(varTE("y", intTpe), varTE("y0", intTpe), intTpe), + varTE("y0", intTpe), intTpe ), varTE("w", intTpe), intTpe ) - ) - ) + ).reprString ) ) @@ -859,18 +868,18 @@ enum L[a]: E, NE(head: a, tail: L[a]) x = ( def go(y, z): - def loop(z): - recur z: + def loop(z1): + recur z1: case E: y case NE(_, t): loop(t) loop(z) - fn1 = z0 -> go(1, z0) + fn1 = z -> go(1, z) fn1(NE(1, NE(2, E))) ) """) { te2 => - assert(te1.void == te2.void, s"${te1.repr} != ${te2.repr}") + assert(te1.void == te2.void, s"\n${te1.reprString}\n\n!=\n\n${te2.reprString}") } } } diff --git a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala index 8f5129d21..db49b5d38 100644 --- a/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala +++ b/core/src/test/scala/org/bykn/bosatsu/codegen/clang/ClangGenTest.scala @@ -75,13 +75,13 @@ int main(int argc, char** argv) { #include #include "gc.h" -BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, BValue __bsts_b_list1) { - if (get_variant(__bsts_b_list1) == 0) { +BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, BValue __bsts_b_list0) { + if (get_variant(__bsts_b_list0) == 0) { return __bstsi_slot[0]; } else { - BValue __bsts_b_h0 = get_enum_index(__bsts_b_list1, 0); - BValue __bsts_b_t0 = get_enum_index(__bsts_b_list1, 1); + BValue __bsts_b_h0 = get_enum_index(__bsts_b_list0, 0); + BValue __bsts_b_t0 = get_enum_index(__bsts_b_list0, 1); return call_fn2(__bstsi_slot[1], __bsts_b_h0, __bsts_t_closure__loop0(__bstsi_slot, __bsts_b_t0)); @@ -113,14 +113,14 @@ int main(int argc, char** argv) { #include "gc.h" BValue __bsts_t_closure0(BValue* __bstsi_slot, - BValue __bsts_b_lst1, - BValue __bsts_b_item1) { + BValue __bsts_b_lst0, + BValue __bsts_b_item0) { BValue __bsts_a_0; BValue __bsts_a_1; BValue __bsts_a_3; BValue __bsts_a_5; - __bsts_a_3 = __bsts_b_lst1; - __bsts_a_5 = __bsts_b_item1; + __bsts_a_3 = __bsts_b_lst0; + __bsts_a_5 = __bsts_b_item0; __bsts_a_0 = alloc_enum0(1); _Bool __bsts_l_cond1; __bsts_l_cond1 = get_variant_value(__bsts_a_0) == 1;