Skip to content

Commit

Permalink
Move to using DB-levels (not DB-indexes) in SExpr0. This change inclu…
Browse files Browse the repository at this point in the history
…des both index and level, and performs a runtime check.

changelog_begin
changelog_end

add runtime check in freeVars: determination that a variable is-free using levels instead of indexes

remove DB-indexes and runtime check; simplify freeVars computation in closure-conversion
  • Loading branch information
nickchapman-da committed Dec 9, 2021
1 parent 39eca49 commit 1783596
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,21 @@ object PlaySpeedy {
def examples: Map[String, (Int, SExpr)] = {

def num(n: Long): SExpr = SEValue(SInt64(n))
def mkVar(level: Int) = SEVarLevel(level)

// The trailing numeral is the number of args at the scala level
// The trailing numeral is the number of args at the scala mkVar

def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, SEValue(SInt64(1))))
val decrement = SEAbs(1, decrement1(SEVar(1)))
val decrement = SEAbs(1, decrement1(mkVar(0)))

def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, y))
val subtract = SEAbs(2, subtract2(SEVar(2), SEVar(1)))
val subtract = SEAbs(2, subtract2(mkVar(0), mkVar(1)))

def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(x))))
val twice = SEAbs(2, twice2(SEVar(2), SEVar(1)))
val twice = SEAbs(2, twice2(mkVar(3), mkVar(4)))

def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(SEApp(f, List(x))))))
val thrice = SEAbs(2, thrice2(SEVar(2), SEVar(1)))
val thrice = SEAbs(2, thrice2(mkVar(0), mkVar(1)))

val examples = List(
(
Expand Down Expand Up @@ -142,7 +143,10 @@ object PlaySpeedy {
List(num(21)),
SEApp(
twice,
List(SEAbs(1, subtract2(SEVar(1), subtract2(SEVar(4), SEVar(2)))), SEVar(2)),
List(
SEAbs(1, subtract2(mkVar(3), subtract2(mkVar(0), mkVar(2)))),
mkVar(1),
),
),
), //100
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,36 @@ import scala.annotation.tailrec

private[speedy] object ClosureConversion {

private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {
import source.SEVarLevel

case class Abs(a: Int) // absolute variable index, determined by tracking sourceDepth
private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {

case class Env(sourceDepth: Int, mapping: Map[Abs, target.SELoc], targetDepth: Int) {
case class Env(sourceDepth: Int, mapping: Map[SEVarLevel, target.SELoc], targetDepth: Int) {

def lookup(abs: Abs): target.SELoc = {
mapping.get(abs) match {
def lookup(v: SEVarLevel): target.SELoc = {
mapping.get(v) match {
case Some(loc) => loc
case None =>
throw sys.error(s"lookup($abs),in:$mapping")
case None => sys.error(s"lookup($v),in:$mapping")
}
}

def extend(n: Int): Env = {
// Create mappings for `n` new stack items, and combine with the (unshifted!) existing mapping.
val m2 = (0 until n).view.map { i =>
val abs = Abs(sourceDepth + i)
(abs, target.SELocAbsoluteS(targetDepth + i))
val v = SEVarLevel(sourceDepth + i)
(v, target.SELocAbsoluteS(targetDepth + i))
}
Env(sourceDepth + n, mapping ++ m2, targetDepth + n)
}

def absBody(arity: Int, fvs: List[Abs]): Env = {
val newRemapsF: Map[Abs, target.SELoc] = fvs.view.zipWithIndex.map { case (abs, i) =>
abs -> target.SELocF(i)
}.toMap
def absBody(arity: Int, freeVars: List[SEVarLevel]): Env = {
val newRemapsF =
freeVars.view.zipWithIndex.map { case (v, i) =>
v -> target.SELocF(i)
}.toMap
val newRemapsA = (0 until arity).view.map { case i =>
val abs = Abs(sourceDepth + i)
abs -> target.SELocA(i)
val v = SEVarLevel(sourceDepth + i)
v -> target.SELocA(i)
}
// The keys in newRemapsF and newRemapsA are disjoint
val m1 = newRemapsF ++ newRemapsA
Expand Down Expand Up @@ -85,7 +85,7 @@ private[speedy] object ClosureConversion {
* forms correspond to the source expression forms: specifically, the location of
* recursive expression instances (values of type SExpr).
*
* For expression forms with no recursive instance (i.e. SEVar, SEVal), there are
* For expression forms with no recursive instance (i.e. SEVarLevel, SEVal), there are
* no corresponding continuation forms.
*
* For expression forms with a single recursive instance (i.e. SELocation), there
Expand All @@ -112,7 +112,7 @@ private[speedy] object ClosureConversion {

final case class Location(loc: Ref.Location) extends Cont

final case class Abs(arity: Int, fvs: List[target.SELoc]) extends Cont
final case class Abs(arity: Int, freeLocs: List[target.SELoc]) extends Cont

final case class App1(env: Env, args: List[source.SExpr]) extends Cont

Expand Down Expand Up @@ -176,9 +176,8 @@ private[speedy] object ClosureConversion {
// Going Down: match on expression form...
case Down(exp, env) =>
exp match {
case source.SEVar(r) =>
val abs = Abs(env.sourceDepth - r)
loop(Up(env.lookup(abs)), conts)
case v: SEVarLevel =>
loop(Up(env.lookup(v)), conts)

case source.SEVal(x) => loop(Up(target.SEVal(x)), conts)
case source.SEBuiltin(x) => loop(Up(target.SEBuiltin(x)), conts)
Expand All @@ -188,11 +187,10 @@ private[speedy] object ClosureConversion {
loop(Down(body, env), Cont.Location(loc) :: conts)

case source.SEAbs(arity, body) =>
val fvsAsListAbs = freeVars(body, arity).toList.sorted.map { r =>
Abs(env.sourceDepth - r)
}
val fvs = fvsAsListAbs.map { abs => env.lookup(abs) }
loop(Down(body, env.absBody(arity, fvsAsListAbs)), Cont.Abs(arity, fvs) :: conts)
val freeVars =
computeFreeVars(body, env.sourceDepth).toList.sortBy(_.level)
val freeLocs = freeVars.map { v => env.lookup(v) }
loop(Down(body, env.absBody(arity, freeVars)), Cont.Abs(arity, freeLocs) :: conts)

case source.SEApp(fun, args) =>
loop(Down(fun, env), Cont.App1(env, args) :: conts)
Expand Down Expand Up @@ -234,9 +232,9 @@ private[speedy] object ClosureConversion {
val body = result
loop(Up(target.SELocation(loc, body)), conts)

case Cont.Abs(arity, fvs) =>
case Cont.Abs(arity, freeLocs) =>
val body = result
loop(Up(target.SEMakeClo(fvs, arity, body)), conts)
loop(Up(target.SEMakeClo(freeLocs, arity, body)), conts)

case Cont.App1(env, args) =>
val fun = result
Expand Down Expand Up @@ -319,56 +317,40 @@ private[speedy] object ClosureConversion {
loop(Down(source0, Env()), Nil)
}

/** Compute the free variables in a speedy expression.
* The returned free variables are de bruijn indices adjusted to the stack of the caller.
*/
private[this] def freeVars(expr0: source.SExpr, depth0: Int): Set[Int] = {
/** Compute the free variables of a speedy expression */
private[this] def computeFreeVars(expr0: source.SExpr, sourceDepth: Int): Set[SEVarLevel] = {
@tailrec // woo hoo, stack safe!
def go(acc: Set[Int], work: List[(source.SExpr, Int)]): Set[Int] = {
def go(acc: Set[SEVarLevel], work: List[source.SExpr]): Set[SEVarLevel] = {
// 'acc' is the (accumulated) set of free variables we have found so far.
// 'work' is a list of source expressions (paired with their depth) which we still have to process.
// 'work' is a list of source expressions which we still have to process.
work match {
case Nil => acc // final result
case (expr, depth) :: work => {
case expr :: work => {
expr match {
case source.SEVar(rel) =>
if (rel > depth) {
val callerRel = rel - depth // adjust to caller's environment
go(acc + callerRel, work)
case v @ SEVarLevel(level) =>
if (level < sourceDepth) {
go(acc + v, work)
} else {
go(acc, work)
}
case _: source.SEVal => go(acc, work)
case _: source.SEBuiltin => go(acc, work)
case _: source.SEValue => go(acc, work)
case source.SELocation(_, body) =>
go(acc, (body, depth) :: work)
case source.SEApp(fun, args) =>
go(acc, (fun :: args).map(e => (e, depth)) ++ work)
case source.SEAbs(n, body) =>
go(acc, (body, depth + n) :: work)
case source.SELocation(_, body) => go(acc, body :: work)
case source.SEApp(fun, args) => go(acc, fun :: args ++ work)
case source.SEAbs(_, body) => go(acc, body :: work)
case source.SECase(scrut, alts) =>
val moreWork = alts.map { case source.SCaseAlt(pat, body) =>
val n = pat.numArgs
(body, depth + n)
}
go(acc, (scrut, depth) :: moreWork ++ work)
case source.SELet(bounds, body) =>
val moreWork = bounds.zipWithIndex.map { case (bound, n) =>
(bound, depth + n)
}
go(acc, (body, depth + bounds.length) :: moreWork ++ work)
case source.SELabelClosure(_, expr) =>
go(acc, (expr, depth) :: work)
case source.SETryCatch(body, handler) =>
go(acc, (handler, 1 + depth) :: (body, depth) :: work)
case source.SEScopeExercise(body) =>
go(acc, (body, depth) :: work)
val bodies = alts.map { case source.SCaseAlt(_, body) => body }
go(acc, scrut :: bodies ++ work)
case source.SELet(bounds, body) => go(acc, body :: bounds ++ work)
case source.SELabelClosure(_, expr) => go(acc, expr :: work)
case source.SETryCatch(body, handler) => go(acc, handler :: body :: work)
case source.SEScopeExercise(body) => go(acc, body :: work)
}
}
}
}
go(Set.empty, List((expr0, depth0)))
go(Set.empty, List(expr0))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.reflect.ClassTag

/** Compiles LF expressions into Speedy expressions.
* This includes:
* - Writing variable references into de Bruijn indices.
* - Translating variable references into de Bruijn levels.
* - Closure conversion: EAbs turns into SEMakeClo, which creates a closure by copying free variables into a closure object.
* - Rewriting of update and scenario actions into applications of builtin functions that take an "effect" token.
*
Expand Down Expand Up @@ -90,14 +90,6 @@ private[lf] object Compiler {

private val SEGetTime = s.SEBuiltin(SBGetTime)

private def SBCompareNumeric(b: SBuiltinPure) =
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVar(2), s.SEVar(1))))
private val SBLessNumeric = SBCompareNumeric(SBLess)
private val SBLessEqNumeric = SBCompareNumeric(SBLessEq)
private val SBGreaterNumeric = SBCompareNumeric(SBGreater)
private val SBGreaterEqNumeric = SBCompareNumeric(SBGreaterEq)
private val SBEqualNumeric = SBCompareNumeric(SBEqual)

private val SBEToTextNumeric = s.SEAbs(1, s.SEBuiltin(SBToText))

private val SENat: Numeric.Scale => Some[s.SEValue] =
Expand Down Expand Up @@ -182,7 +174,7 @@ private[lf] final class Compiler(
varIndices: Map[VarRef, Position],
) {

def toSEVar(p: Position): s.SEVar = s.SEVar(position - p.idx)
def toSEVar(p: Position) = s.SEVarLevel(p.idx)

def nextPosition = Position(position)

Expand Down Expand Up @@ -214,14 +206,14 @@ private[lf] final class Compiler(

private[this] def vars: List[VarRef] = varIndices.keys.toList

private[this] def lookupVar(varRef: VarRef): Option[s.SEVar] =
private[this] def lookupVar(varRef: VarRef): Option[s.SExpr] =
varIndices.get(varRef).map(toSEVar)

def lookupExprVar(name: ExprVarName): s.SEVar =
def lookupExprVar(name: ExprVarName): s.SExpr =
lookupVar(EVarRef(name))
.getOrElse(throw CompilationError(s"Unknown variable: $name. Known: ${vars.mkString(",")}"))

def lookupTypeVar(name: TypeVarName): Option[s.SEVar] =
def lookupTypeVar(name: TypeVarName): Option[s.SExpr] =
lookupVar(TVarRef(name))

}
Expand Down Expand Up @@ -477,7 +469,7 @@ private[lf] final class Compiler(
case EVal(ref) =>
s.SEVal(t.LfDefRef(ref))
case EBuiltin(bf) =>
compileBuiltin(bf)
compileBuiltin(env, bf)
case EPrimCon(con) =>
compilePrimCon(con)
case EPrimLit(lit) =>
Expand Down Expand Up @@ -587,9 +579,24 @@ private[lf] final class Compiler(
}

@inline
private[this] def compileBuiltin(bf: BuiltinFunction): s.SExpr =
private[this] def compileIdentity(env: Env) = s.SEAbs(1, s.SEVarLevel(env.position))

@inline
private[this] def compileBuiltin(env: Env, bf: BuiltinFunction): s.SExpr = {

def SBCompareNumeric(b: SBuiltinPure) = {
val d = env.position
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVarLevel(d + 1), s.SEVarLevel(d + 2))))
}

val SBLessNumeric = SBCompareNumeric(SBLess)
val SBLessEqNumeric = SBCompareNumeric(SBLessEq)
val SBGreaterNumeric = SBCompareNumeric(SBGreater)
val SBGreaterEqNumeric = SBCompareNumeric(SBGreaterEq)
val SBEqualNumeric = SBCompareNumeric(SBEqual)

bf match {
case BCoerceContractId => s.SEAbs.identity
case BCoerceContractId => compileIdentity(env)
// Numeric Comparisons
case BLessNumeric => SBLessNumeric
case BLessEqNumeric => SBLessEqNumeric
Expand Down Expand Up @@ -712,6 +719,7 @@ private[lf] final class Compiler(
case BAnyExceptionMessage => SBAnyExceptionMessage
})
}
}

@inline
private[this] def compilePrimCon(con: PrimCon): s.SExpr =
Expand Down Expand Up @@ -1345,7 +1353,7 @@ private[lf] final class Compiler(
ifaceId: Identifier,
): (t.SDefinitionRef, SDefinition) =
t.ImplementsDefRef(tmplId, ifaceId) ->
SDefinition(unsafeClosureConvert(s.SEAbs.identity))
SDefinition(unsafeClosureConvert(compileIdentity(Env.Empty)))

// Compile the implementation of an interface method.
private[this] def compileImplementsMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ package speedy
*
* 1: convert from LF
* - reducing binding forms (update & scenario becoming builtins)
* - moving to de Bruijn indexing for variable
* - convert named variables to de Bruijn levels
* - moving to multi-argument applications and abstractions.
*
* 2: closure conversion
Expand Down Expand Up @@ -54,13 +54,11 @@ private[speedy] object SExpr0 {

sealed abstract class SExpr extends Product with Serializable

/** Reference to a variable. 'index' is the 1-based de Bruijn index,
* that is, SEVar(1) points to the nearest enclosing variable binder.
* which could be an SELam, SELet, or a binding variant of SECasePat.
/** Reference to a variable. 'level' is the 0-based de Bruijn LEVEL (not INDEX)
* https://en.wikipedia.org/wiki/De_Bruijn_index
* This expression form is only allowed prior to closure conversion
*/
final case class SEVar(index: Int) extends SExpr
final case class SEVarLevel(level: Int) extends SExpr

/** Reference to a value. On first lookup the evaluated expression is
* stored in 'cached'.
Expand All @@ -81,14 +79,6 @@ private[speedy] object SExpr0 {
/** Lambda abstraction. Transformed to SEMakeClo during closure conversion */
final case class SEAbs(arity: Int, body: SExpr) extends SExpr

object SEAbs {
// Helper for constructing abstraction expressions:
// SEAbs(1) { ... }
def apply(arity: Int)(body: SExpr): SExpr = SEAbs(arity, body)

val identity: SEAbs = SEAbs(1, SEVar(1))
}

/** Pattern match. */
final case class SECase(scrut: SExpr, alts: List[SCaseAlt]) extends SExpr

Expand Down

0 comments on commit 1783596

Please sign in to comment.