Skip to content

Commit

Permalink
fix lurking bug in python generation
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Dec 12, 2024
1 parent b79696f commit 80eb379
Showing 1 changed file with 38 additions and 32 deletions.
70 changes: 38 additions & 32 deletions core/src/main/scala/org/bykn/bosatsu/codegen/python/PythonGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,41 +47,53 @@ object PythonGen {

private object Impl {

case class BindState(binding: Bindable, count: Int, stack: List[Code.Ident]) {
def currentOption: Option[Code.Ident] = stack.headOption

def current: Code.Ident =
stack match {
case h :: _ => h
case Nil => sys.error(s"invariant violation: $binding, count = $count has no bindings.")
}

def next: (BindState, Code.Ident) = {
val pname = Code.Ident(Idents.escape("___b", binding.asString + count.toString))
(copy(count = count + 1, stack = pname :: stack), pname)
}

def pop: BindState =
stack match {
case _ :: tail => copy(stack = tail)
case Nil => sys.error(s"invariant violation: $binding, count = $count has no bindings to pop")
}
}
object BindState {
def empty(b: Bindable): BindState = BindState(b, 0, Nil)
}

case class EnvState(
imports: Map[Module, Code.Ident],
bindings: Map[Bindable, (Int, List[Code.Ident])],
bindings: Map[Bindable, BindState],
tops: Set[Bindable],
nextTmp: Long
) {

private def bindInc(b: Bindable, inc: Int)(
fn: Int => Code.Ident
): (EnvState, Code.Ident) = {
val (c, s) = bindings.getOrElse(b, (0, Nil))
val pname = fn(c)
def bind(b: Bindable): (EnvState, Code.Ident) = {
val bs = bindings.getOrElse(b, BindState.empty(b))
val (bs1, pname) = bs.next

(
copy(
bindings = bindings.updated(b, (c + inc, pname :: s))
bindings = bindings.updated(b, bs1)
),
pname
)
}

def bind(b: Bindable): (EnvState, Code.Ident) =
bindInc(b, 1) { c =>
Code.Ident(Idents.escape("___b", b.asString + c.toString))
}

// in loops we need to substitute
// bindings for mutable variables
def subs(b: Bindable, c: Code.Ident): EnvState =
bindInc(b, 0)(_ => c)._1

def deref(b: Bindable): Code.Ident =
// see if we are shadowing, or top level
bindings.get(b) match {
case Some((_, h :: _)) => h
bindings.get(b).flatMap(_.currentOption) match {
case Some(h) => h
case _ if tops(b) => escape(b)
case other =>
// $COVERAGE-OFF$
Expand All @@ -93,8 +105,8 @@ object PythonGen {

def unbind(b: Bindable): EnvState =
bindings.get(b) match {
case Some((cnt, _ :: tail)) =>
copy(bindings = bindings.updated(b, (cnt, tail)))
case Some(bs) =>
copy(bindings = bindings.updated(b, bs.pop))
case other =>
// $COVERAGE-OFF$
throw new IllegalStateException(
Expand Down Expand Up @@ -167,10 +179,6 @@ object PythonGen {
def bind(b: Bindable): Env[Code.Ident] =
Impl.env(_.bind(b))

// point this name to the top level name
def subs(b: Bindable, i: Code.Ident): Env[Unit] =
Impl.update(_.subs(b, i))

// get the mapping for a name in scope
def deref(b: Bindable): Env[Code.Ident] =
Impl.read(_.deref(b))
Expand Down Expand Up @@ -1567,11 +1575,8 @@ object PythonGen {
): Env[Statement] =
expr match {
case Lambda(captures, _, args, body) =>
// we can ignore name because python already allows recursion
// we can use topLevelName on makeDefs since they are already
// shadowing in the same rules as bosatsu
(
args.traverse(Env.topLevelName(_)),
args.traverse(Env.bind(_)),
makeSlots(captures, slotName)(loop(body, _))
)
.mapN { case (as, (slots, body)) =>
Expand All @@ -1580,7 +1585,7 @@ object PythonGen {
Env.makeDef(name, as, body) ::
Nil
)
}
} <* args.traverse_(Env.unbind(_))
}

def makeSlots[A](captures: List[Expr], slotName: Option[Code.Ident])(
Expand Down Expand Up @@ -1621,7 +1626,7 @@ object PythonGen {
case Some(n) => Env.bind(n)
}
(
args.traverse(Env.topLevelName(_)),
args.traverse(Env.bind(_)),
defName,
makeSlots(captures, slotName)(loop(res, _))
)
Expand All @@ -1634,7 +1639,8 @@ object PythonGen {
defn = Env.makeDef(defName, args, v)
block = Code.blockFromList(prefix.toList ::: defn :: Nil)
} yield block.withValue(defName)
}
} <* args.traverse_(Env.unbind(_))

case WhileExpr(cond, effect, res) =>
(boolExpr(cond, slotName), loop(effect, slotName), loop(res, slotName), Env.newAssignableVar)
.mapN { (cond, effect, res, c) =>
Expand Down

0 comments on commit 80eb379

Please sign in to comment.