diff --git a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala index 496898f54..a222086c8 100644 --- a/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala +++ b/core/src/main/scala/org/bykn/bosatsu/TypedExprNormalization.scala @@ -25,6 +25,9 @@ object TypedExprNormalization { def -(key: Bindable): Scope[A] = FixType.fix[ScopeT[A, *]](FixType.unfix[ScopeT[A, *]](scope) - (None -> key)) + def --(keys: Iterable[Bindable]): Scope[A] = + keys.foldLeft(scope)(_ - _) + def getLocal(key: Bindable): Option[(RecursionKind, TypedExpr[A], Scope[A])] = FixType.unfix[ScopeT[A, *]](scope).get((None, key)) @@ -138,7 +141,7 @@ object TypedExprNormalization { lazy val anons: Iterator[Bindable] = Expr.nameIterator() .filterNot(expr.freeVarsDup.toSet) - val e1 = normalize1(None, expr, scope, typeEnv).get + val e1 = normalize1(None, expr, scope -- lamArgs0.toList.map(_._1), typeEnv).get var changed = false val lamArgs = lamArgs0.map { case (n, t) => @@ -182,15 +185,17 @@ object TypedExprNormalization { case App(fn, aargs, _, _) if matchesArgs(aargs) && doesntUseArgs(fn) => // x -> f(x) == f (eta conversion) normalize1(None, setType(fn, te.getType), scope, typeEnv) - /* - case App(ws.ResolveToLambda(Nil, args1, body, ftag), aargs, resT, tag) => + case App(ws.ResolveToLambda(Nil, args1, body, ftag), aargs, resT, atag) if namerec.isEmpty => // args -> (args1 -> e1)(...) // this is inlining, which we do only when nested directly inside another lambda + // TODO: this is possibly very expensive to always apply. It can really increase + // code size. We probably need better hueristics for when to inline, + // or remove inlining from here unless it can hever hurt and put inlining at a + // different phase. val fn1 = AnnotatedLambda(args1, body, ftag) normalize1(namerec, - AnnotatedLambda(lamArgs, App(fn1, aargs, resT, tag), tag), + AnnotatedLambda(lamArgs, App(fn1, aargs, resT, atag), tag), scope, typeEnv) - */ case Let(arg1, ex, in, rec, tag1) if doesntUseArgs(ex) && doesntShadow(arg1) => // x -> // y = z @@ -351,7 +356,7 @@ object TypedExprNormalization { def ncount(shadows: Iterable[Bindable], e: TypedExpr[A]): (Int, TypedExpr[A]) = // the final result of the branch is what is assigned to the name - normalizeLetOpt(None, e, shadows.foldLeft(scope)(_ - _), typeEnv) match { + normalizeLetOpt(None, e, scope -- shadows, typeEnv) match { case None => (0, e) case Some(e) => (1, e) }