Skip to content

Commit

Permalink
remove DB-indexes and runtime check; simplify freeVars computation in…
Browse files Browse the repository at this point in the history
… closure-conversion
  • Loading branch information
nickchapman-da committed Dec 6, 2021
1 parent 21899ff commit d051ead
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,21 @@ object PlaySpeedy {
def examples: Map[String, (Int, SExpr)] = {

def num(n: Long): SExpr = SEValue(SInt64(n))
def mkVar(level: Int) = SEVarLevel(level)

def seVar(abs: Int, rel: Int) = SEVarLevel(abs, rel)

// The trailing numeral is the number of args at the scala level
// The trailing numeral is the number of args at the scala mkVar

def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, SEValue(SInt64(1))))
val decrement = SEAbs(1, decrement1(seVar(0, 1)))
val decrement = SEAbs(1, decrement1(mkVar(0)))

def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, y))
val subtract = SEAbs(2, subtract2(seVar(0, 2), seVar(1, 1)))
val subtract = SEAbs(2, subtract2(mkVar(0), mkVar(1)))

def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(x))))
val twice = SEAbs(2, twice2(seVar(3, 2), seVar(4, 1)))
val twice = SEAbs(2, twice2(mkVar(3), mkVar(4)))

def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(SEApp(f, List(x))))))
val thrice = SEAbs(2, thrice2(seVar(0, 2), seVar(1, 1)))
val thrice = SEAbs(2, thrice2(mkVar(0), mkVar(1)))

val examples = List(
(
Expand Down Expand Up @@ -145,8 +144,8 @@ object PlaySpeedy {
SEApp(
twice,
List(
SEAbs(1, subtract2(seVar(3, 1), subtract2(seVar(0, 4), seVar(2, 2)))),
seVar(1, 2),
SEAbs(1, subtract2(mkVar(3), subtract2(mkVar(0), mkVar(2)))),
mkVar(1),
),
),
), //100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,36 @@ import scala.annotation.tailrec

private[speedy] object ClosureConversion {

private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {
import source.SEVarLevel

case class Abs(a: Int) // absolute variable index, determined by tracking sourceDepth
private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {

case class Env(sourceDepth: Int, mapping: Map[Abs, target.SELoc], targetDepth: Int) {
case class Env(sourceDepth: Int, mapping: Map[SEVarLevel, target.SELoc], targetDepth: Int) {

def lookup(v: source.SEVarLevel, absReconstructed: Abs): target.SELoc = {
val source.SEVarLevel(absProvided, _) = v
if (Abs(absProvided) != absReconstructed) { //NICK: remove runtime check
sys.error(s"**lookup($absProvided / $absReconstructed) -- DIFF")
}
//val abs = Abs(absProvided)
val abs = absReconstructed
mapping.get(abs) match {
def lookup(v: SEVarLevel): target.SELoc = {
mapping.get(v) match {
case Some(loc) => loc
case None =>
throw sys.error(s"lookup($abs),in:$mapping")
case None => sys.error(s"lookup($v),in:$mapping")
}
}

def extend(n: Int): Env = {
// Create mappings for `n` new stack items, and combine with the (unshifted!) existing mapping.
val m2 = (0 until n).view.map { i =>
val abs = Abs(sourceDepth + i)
(abs, target.SELocAbsoluteS(targetDepth + i))
val v = SEVarLevel(sourceDepth + i)
(v, target.SELocAbsoluteS(targetDepth + i))
}
Env(sourceDepth + n, mapping ++ m2, targetDepth + n)
}

def absBody(arity: Int, fvs: List[(source.SEVarLevel, Abs)]): Env = {
val newRemapsF: Map[Abs, target.SELoc] =
fvs.view.zipWithIndex.map { case ((v @ _, abs), i) =>
abs -> target.SELocF(i)
def absBody(arity: Int, freeVars: List[SEVarLevel]): Env = {
val newRemapsF =
freeVars.view.zipWithIndex.map { case (v, i) =>
v -> target.SELocF(i)
}.toMap
val newRemapsA = (0 until arity).view.map { case i =>
val abs = Abs(sourceDepth + i)
abs -> target.SELocA(i)
val v = SEVarLevel(sourceDepth + i)
v -> target.SELocA(i)
}
// The keys in newRemapsF and newRemapsA are disjoint
val m1 = newRemapsF ++ newRemapsA
Expand Down Expand Up @@ -119,7 +112,7 @@ private[speedy] object ClosureConversion {

final case class Location(loc: Ref.Location) extends Cont

final case class Abs(arity: Int, fvs: List[target.SELoc]) extends Cont
final case class Abs(arity: Int, freeLocs: List[target.SELoc]) extends Cont

final case class App1(env: Env, args: List[source.SExpr]) extends Cont

Expand Down Expand Up @@ -183,9 +176,8 @@ private[speedy] object ClosureConversion {
// Going Down: match on expression form...
case Down(exp, env) =>
exp match {
case v @ source.SEVarLevel(_, rel) =>
val abs = Abs(env.sourceDepth - rel)
loop(Up(env.lookup(v, abs)), conts)
case v: SEVarLevel =>
loop(Up(env.lookup(v)), conts)

case source.SEVal(x) => loop(Up(target.SEVal(x)), conts)
case source.SEBuiltin(x) => loop(Up(target.SEBuiltin(x)), conts)
Expand All @@ -195,12 +187,10 @@ private[speedy] object ClosureConversion {
loop(Down(body, env), Cont.Location(loc) :: conts)

case source.SEAbs(arity, body) =>
val fvsAsListAbs: List[(source.SEVarLevel, Abs)] =
freeVars(body, arity, env.sourceDepth).toList.sortBy(_._2).map { case (v, rel) =>
(v, Abs(env.sourceDepth - rel))
}
val fvs = fvsAsListAbs.map { case (v, abs) => env.lookup(v, abs) }
loop(Down(body, env.absBody(arity, fvsAsListAbs)), Cont.Abs(arity, fvs) :: conts)
val freeVars =
computeFreeVars(body, env.sourceDepth).toList.sortBy(_.level)
val freeLocs = freeVars.map { v => env.lookup(v) }
loop(Down(body, env.absBody(arity, freeVars)), Cont.Abs(arity, freeLocs) :: conts)

case source.SEApp(fun, args) =>
loop(Down(fun, env), Cont.App1(env, args) :: conts)
Expand Down Expand Up @@ -242,9 +232,9 @@ private[speedy] object ClosureConversion {
val body = result
loop(Up(target.SELocation(loc, body)), conts)

case Cont.Abs(arity, fvs) =>
case Cont.Abs(arity, freeLocs) =>
val body = result
loop(Up(target.SEMakeClo(fvs, arity, body)), conts)
loop(Up(target.SEMakeClo(freeLocs, arity, body)), conts)

case Cont.App1(env, args) =>
val fun = result
Expand Down Expand Up @@ -327,65 +317,40 @@ private[speedy] object ClosureConversion {
loop(Down(source0, Env()), Nil)
}

/** Compute the free variables in a speedy expression.
* The returned free variables are de bruijn indices adjusted to the stack of the caller.
*/

type Res = (source.SEVarLevel, Int)

private[this] def freeVars(expr0: source.SExpr, depth0: Int, sourceDepth: Int): Set[Res] = {
/** Compute the free variables of a speedy expression */
private[this] def computeFreeVars(expr0: source.SExpr, sourceDepth: Int): Set[SEVarLevel] = {
@tailrec // woo hoo, stack safe!
def go(acc: Set[Res], work: List[(source.SExpr, Int)]): Set[Res] = { //NICK: remove depth from work-list
def go(acc: Set[SEVarLevel], work: List[source.SExpr]): Set[SEVarLevel] = {
// 'acc' is the (accumulated) set of free variables we have found so far.
// 'work' is a list of source expressions (paired with their depth) which we still have to process.
// 'work' is a list of source expressions which we still have to process.
work match {
case Nil => acc // final result
case (expr, depth) :: work => {
case expr :: work => {
expr match {
case v @ source.SEVarLevel(abs, rel) =>
val isFree1: Boolean = rel > depth // old
val isFree2: Boolean = abs < sourceDepth //compute without use of rel
if (isFree1 != isFree2) { //NICK: remove runtime check
sys.error(s"**freeVars: ($isFree1 / $isFree2) -- DIFF")
}
if (isFree2) {
val callerRel = rel - depth // adjust to caller's environment
val x: Res = (v, callerRel)
go(acc + x, work)
case v @ SEVarLevel(level) =>
if (level < sourceDepth) {
go(acc + v, work)
} else {
go(acc, work)
}
case _: source.SEVal => go(acc, work)
case _: source.SEBuiltin => go(acc, work)
case _: source.SEValue => go(acc, work)
case source.SELocation(_, body) =>
go(acc, (body, depth) :: work)
case source.SEApp(fun, args) =>
go(acc, (fun :: args).map(e => (e, depth)) ++ work)
case source.SEAbs(n, body) =>
go(acc, (body, depth + n) :: work)
case source.SELocation(_, body) => go(acc, body :: work)
case source.SEApp(fun, args) => go(acc, fun :: args ++ work)
case source.SEAbs(_, body) => go(acc, body :: work)
case source.SECase(scrut, alts) =>
val moreWork = alts.map { case source.SCaseAlt(pat, body) =>
val n = pat.numArgs
(body, depth + n)
}
go(acc, (scrut, depth) :: moreWork ++ work)
case source.SELet(bounds, body) =>
val moreWork = bounds.zipWithIndex.map { case (bound, n) =>
(bound, depth + n)
}
go(acc, (body, depth + bounds.length) :: moreWork ++ work)
case source.SELabelClosure(_, expr) =>
go(acc, (expr, depth) :: work)
case source.SETryCatch(body, handler) =>
go(acc, (handler, 1 + depth) :: (body, depth) :: work)
case source.SEScopeExercise(body) =>
go(acc, (body, depth) :: work)
val bodies = alts.map { case source.SCaseAlt(_, body) => body }
go(acc, scrut :: bodies ++ work)
case source.SELet(bounds, body) => go(acc, body :: bounds ++ work)
case source.SELabelClosure(_, expr) => go(acc, expr :: work)
case source.SETryCatch(body, handler) => go(acc, handler :: body :: work)
case source.SEScopeExercise(body) => go(acc, body :: work)
}
}
}
}
go(Set.empty, List((expr0, depth0)))
go(Set.empty, List(expr0))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.reflect.ClassTag

/** Compiles LF expressions into Speedy expressions.
* This includes:
* - Writing variable references into de Bruijn indices.
* - Translating variable references into de Bruijn levels.
* - Closure conversion: EAbs turns into SEMakeClo, which creates a closure by copying free variables into a closure object.
* - Rewriting of update and scenario actions into applications of builtin functions that take an "effect" token.
*
Expand Down Expand Up @@ -174,7 +174,7 @@ private[lf] final class Compiler(
varIndices: Map[VarRef, Position],
) {

def toSEVar(p: Position): s.SExpr = s.SEVarLevel(p.idx, position - p.idx)
def toSEVar(p: Position) = s.SEVarLevel(p.idx)

def nextPosition = Position(position)

Expand Down Expand Up @@ -575,19 +575,14 @@ private[lf] final class Compiler(
}

@inline
private[this] def compileIdentity(env: Env): s.SExpr =
s.SEAbs(1, s.SEVarLevel(env.position, 1))
private[this] def compileIdentity(env: Env) = s.SEAbs(1, s.SEVarLevel(env.position))

@inline
private[this] def compileBuiltin(env: Env, bf: BuiltinFunction): s.SExpr = {

def SBCompareNumeric(b: SBuiltinPure) = {
val d = env.position
// TODO: It would be better to have new builtins than to manufacture syntax like this
s.SEAbs(
3,
s.SEApp(s.SEBuiltin(b), List(s.SEVarLevel(d + 1, 2), s.SEVarLevel(d + 2, 1))),
)
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVarLevel(d + 1), s.SEVarLevel(d + 2))))
}

val SBLessNumeric = SBCompareNumeric(SBLess)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ package speedy
*
* 1: convert from LF
* - reducing binding forms (update & scenario becoming builtins)
* - moving to de Bruijn indexing for variable
* - convert named variables to de Bruijn levels
* - moving to multi-argument applications and abstractions.
*
* 2: closure conversion
Expand Down Expand Up @@ -58,7 +58,7 @@ private[speedy] object SExpr0 {
* https://en.wikipedia.org/wiki/De_Bruijn_index
* This expression form is only allowed prior to closure conversion
*/
final case class SEVarLevel(level: Int, index: Int) extends SExpr //NICK: kill index
final case class SEVarLevel(level: Int) extends SExpr

/** Reference to a value. On first lookup the evaluated expression is
* stored in 'cached'.
Expand Down

0 comments on commit d051ead

Please sign in to comment.