Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use DeBruijn levels in SExpr0. #11987

Merged
merged 1 commit into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,21 @@ object PlaySpeedy {
def examples: Map[String, (Int, SExpr)] = {

def num(n: Long): SExpr = SEValue(SInt64(n))
def mkVar(level: Int) = SEVarLevel(level)

// The trailing numeral is the number of args at the scala level
// The trailing numeral is the number of args at the scala mkVar

def decrement1(x: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, SEValue(SInt64(1))))
val decrement = SEAbs(1, decrement1(SEVar(1)))
val decrement = SEAbs(1, decrement1(mkVar(0)))

def subtract2(x: SExpr, y: SExpr): SExpr = SEApp(SEBuiltin(SBSubInt64), List(x, y))
val subtract = SEAbs(2, subtract2(SEVar(2), SEVar(1)))
val subtract = SEAbs(2, subtract2(mkVar(0), mkVar(1)))

def twice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(x))))
val twice = SEAbs(2, twice2(SEVar(2), SEVar(1)))
val twice = SEAbs(2, twice2(mkVar(3), mkVar(4)))

def thrice2(f: SExpr, x: SExpr): SExpr = SEApp(f, List(SEApp(f, List(SEApp(f, List(x))))))
val thrice = SEAbs(2, thrice2(SEVar(2), SEVar(1)))
val thrice = SEAbs(2, thrice2(mkVar(0), mkVar(1)))
Comment on lines 97 to +107
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not very important, but I do not really like the use of global index in this class.
How are they calculated ? ...
I assume the level is somehow global, and those expression should be use in a very specific place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.
Actually, I think there is very little value in this class. I should delete it.
It's left over form my very first explorations into speedy.


val examples = List(
(
Expand Down Expand Up @@ -142,7 +143,10 @@ object PlaySpeedy {
List(num(21)),
SEApp(
twice,
List(SEAbs(1, subtract2(SEVar(1), subtract2(SEVar(4), SEVar(2)))), SEVar(2)),
List(
SEAbs(1, subtract2(mkVar(3), subtract2(mkVar(0), mkVar(2)))),
mkVar(1),
),
),
), //100
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,36 @@ import scala.annotation.tailrec

private[speedy] object ClosureConversion {

private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {
import source.SEVarLevel

case class Abs(a: Int) // absolute variable index, determined by tracking sourceDepth
private[speedy] def closureConvert(source0: source.SExpr): target.SExpr = {

case class Env(sourceDepth: Int, mapping: Map[Abs, target.SELoc], targetDepth: Int) {
case class Env(sourceDepth: Int, mapping: Map[SEVarLevel, target.SELoc], targetDepth: Int) {

def lookup(abs: Abs): target.SELoc = {
mapping.get(abs) match {
def lookup(v: SEVarLevel): target.SELoc = {
mapping.get(v) match {
case Some(loc) => loc
case None =>
throw sys.error(s"lookup($abs),in:$mapping")
case None => sys.error(s"lookup($v),in:$mapping")
Comment on lines +28 to +30
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not using getOrElse ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤷‍♂️

}
}

def extend(n: Int): Env = {
// Create mappings for `n` new stack items, and combine with the (unshifted!) existing mapping.
val m2 = (0 until n).view.map { i =>
val abs = Abs(sourceDepth + i)
(abs, target.SELocAbsoluteS(targetDepth + i))
val v = SEVarLevel(sourceDepth + i)
(v, target.SELocAbsoluteS(targetDepth + i))
}
Env(sourceDepth + n, mapping ++ m2, targetDepth + n)
}

def absBody(arity: Int, fvs: List[Abs]): Env = {
val newRemapsF: Map[Abs, target.SELoc] = fvs.view.zipWithIndex.map { case (abs, i) =>
abs -> target.SELocF(i)
}.toMap
def absBody(arity: Int, freeVars: List[SEVarLevel]): Env = {
val newRemapsF =
freeVars.view.zipWithIndex.map { case (v, i) =>
v -> target.SELocF(i)
}.toMap
val newRemapsA = (0 until arity).view.map { case i =>
val abs = Abs(sourceDepth + i)
abs -> target.SELocA(i)
val v = SEVarLevel(sourceDepth + i)
v -> target.SELocA(i)
}
// The keys in newRemapsF and newRemapsA are disjoint
val m1 = newRemapsF ++ newRemapsA
Expand Down Expand Up @@ -85,7 +85,7 @@ private[speedy] object ClosureConversion {
* forms correspond to the source expression forms: specifically, the location of
* recursive expression instances (values of type SExpr).
*
* For expression forms with no recursive instance (i.e. SEVar, SEVal), there are
* For expression forms with no recursive instance (i.e. SEVarLevel, SEVal), there are
* no corresponding continuation forms.
*
* For expression forms with a single recursive instance (i.e. SELocation), there
Expand All @@ -112,7 +112,7 @@ private[speedy] object ClosureConversion {

final case class Location(loc: Ref.Location) extends Cont

final case class Abs(arity: Int, fvs: List[target.SELoc]) extends Cont
final case class Abs(arity: Int, freeLocs: List[target.SELoc]) extends Cont

final case class App1(env: Env, args: List[source.SExpr]) extends Cont

Expand Down Expand Up @@ -176,9 +176,8 @@ private[speedy] object ClosureConversion {
// Going Down: match on expression form...
case Down(exp, env) =>
exp match {
case source.SEVar(r) =>
val abs = Abs(env.sourceDepth - r)
loop(Up(env.lookup(abs)), conts)
case v: SEVarLevel =>
loop(Up(env.lookup(v)), conts)

case source.SEVal(x) => loop(Up(target.SEVal(x)), conts)
case source.SEBuiltin(x) => loop(Up(target.SEBuiltin(x)), conts)
Expand All @@ -188,11 +187,10 @@ private[speedy] object ClosureConversion {
loop(Down(body, env), Cont.Location(loc) :: conts)

case source.SEAbs(arity, body) =>
val fvsAsListAbs = freeVars(body, arity).toList.sorted.map { r =>
Abs(env.sourceDepth - r)
}
val fvs = fvsAsListAbs.map { abs => env.lookup(abs) }
loop(Down(body, env.absBody(arity, fvsAsListAbs)), Cont.Abs(arity, fvs) :: conts)
val freeVars =
computeFreeVars(body, env.sourceDepth).toList.sortBy(_.level)
val freeLocs = freeVars.map { v => env.lookup(v) }
loop(Down(body, env.absBody(arity, freeVars)), Cont.Abs(arity, freeLocs) :: conts)

case source.SEApp(fun, args) =>
loop(Down(fun, env), Cont.App1(env, args) :: conts)
Expand Down Expand Up @@ -234,9 +232,9 @@ private[speedy] object ClosureConversion {
val body = result
loop(Up(target.SELocation(loc, body)), conts)

case Cont.Abs(arity, fvs) =>
case Cont.Abs(arity, freeLocs) =>
val body = result
loop(Up(target.SEMakeClo(fvs, arity, body)), conts)
loop(Up(target.SEMakeClo(freeLocs, arity, body)), conts)

case Cont.App1(env, args) =>
val fun = result
Expand Down Expand Up @@ -319,56 +317,40 @@ private[speedy] object ClosureConversion {
loop(Down(source0, Env()), Nil)
}

/** Compute the free variables in a speedy expression.
* The returned free variables are de bruijn indices adjusted to the stack of the caller.
*/
private[this] def freeVars(expr0: source.SExpr, depth0: Int): Set[Int] = {
/** Compute the free variables of a speedy expression */
private[this] def computeFreeVars(expr0: source.SExpr, sourceDepth: Int): Set[SEVarLevel] = {
@tailrec // woo hoo, stack safe!
def go(acc: Set[Int], work: List[(source.SExpr, Int)]): Set[Int] = {
def go(acc: Set[SEVarLevel], work: List[source.SExpr]): Set[SEVarLevel] = {
// 'acc' is the (accumulated) set of free variables we have found so far.
// 'work' is a list of source expressions (paired with their depth) which we still have to process.
// 'work' is a list of source expressions which we still have to process.
work match {
case Nil => acc // final result
case (expr, depth) :: work => {
case expr :: work => {
expr match {
case source.SEVar(rel) =>
if (rel > depth) {
val callerRel = rel - depth // adjust to caller's environment
go(acc + callerRel, work)
case v @ SEVarLevel(level) =>
if (level < sourceDepth) {
go(acc + v, work)
} else {
go(acc, work)
}
case _: source.SEVal => go(acc, work)
case _: source.SEBuiltin => go(acc, work)
case _: source.SEValue => go(acc, work)
case source.SELocation(_, body) =>
go(acc, (body, depth) :: work)
case source.SEApp(fun, args) =>
go(acc, (fun :: args).map(e => (e, depth)) ++ work)
case source.SEAbs(n, body) =>
go(acc, (body, depth + n) :: work)
case source.SELocation(_, body) => go(acc, body :: work)
case source.SEApp(fun, args) => go(acc, fun :: args ++ work)
case source.SEAbs(_, body) => go(acc, body :: work)
case source.SECase(scrut, alts) =>
val moreWork = alts.map { case source.SCaseAlt(pat, body) =>
val n = pat.numArgs
(body, depth + n)
}
go(acc, (scrut, depth) :: moreWork ++ work)
case source.SELet(bounds, body) =>
val moreWork = bounds.zipWithIndex.map { case (bound, n) =>
(bound, depth + n)
}
go(acc, (body, depth + bounds.length) :: moreWork ++ work)
case source.SELabelClosure(_, expr) =>
go(acc, (expr, depth) :: work)
case source.SETryCatch(body, handler) =>
go(acc, (handler, 1 + depth) :: (body, depth) :: work)
case source.SEScopeExercise(body) =>
go(acc, (body, depth) :: work)
val bodies = alts.map { case source.SCaseAlt(_, body) => body }
go(acc, scrut :: bodies ++ work)
case source.SELet(bounds, body) => go(acc, body :: bounds ++ work)
case source.SELabelClosure(_, expr) => go(acc, expr :: work)
case source.SETryCatch(body, handler) => go(acc, handler :: body :: work)
case source.SEScopeExercise(body) => go(acc, body :: work)
}
}
}
}
go(Set.empty, List((expr0, depth0)))
go(Set.empty, List(expr0))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.reflect.ClassTag

/** Compiles LF expressions into Speedy expressions.
* This includes:
* - Writing variable references into de Bruijn indices.
* - Translating variable references into de Bruijn levels.
* - Closure conversion: EAbs turns into SEMakeClo, which creates a closure by copying free variables into a closure object.
* - Rewriting of update and scenario actions into applications of builtin functions that take an "effect" token.
*
Expand Down Expand Up @@ -90,14 +90,6 @@ private[lf] object Compiler {

private val SEGetTime = s.SEBuiltin(SBGetTime)

private def SBCompareNumeric(b: SBuiltinPure) =
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVar(2), s.SEVar(1))))
private val SBLessNumeric = SBCompareNumeric(SBLess)
private val SBLessEqNumeric = SBCompareNumeric(SBLessEq)
private val SBGreaterNumeric = SBCompareNumeric(SBGreater)
private val SBGreaterEqNumeric = SBCompareNumeric(SBGreaterEq)
private val SBEqualNumeric = SBCompareNumeric(SBEqual)

private val SBEToTextNumeric = s.SEAbs(1, s.SEBuiltin(SBToText))

private val SENat: Numeric.Scale => Some[s.SEValue] =
Expand Down Expand Up @@ -182,7 +174,7 @@ private[lf] final class Compiler(
varIndices: Map[VarRef, Position],
) {

def toSEVar(p: Position): s.SEVar = s.SEVar(position - p.idx)
def toSEVar(p: Position) = s.SEVarLevel(p.idx)

def nextPosition = Position(position)

Expand Down Expand Up @@ -214,14 +206,14 @@ private[lf] final class Compiler(

private[this] def vars: List[VarRef] = varIndices.keys.toList

private[this] def lookupVar(varRef: VarRef): Option[s.SEVar] =
private[this] def lookupVar(varRef: VarRef): Option[s.SExpr] =
varIndices.get(varRef).map(toSEVar)

def lookupExprVar(name: ExprVarName): s.SEVar =
def lookupExprVar(name: ExprVarName): s.SExpr =
lookupVar(EVarRef(name))
.getOrElse(throw CompilationError(s"Unknown variable: $name. Known: ${vars.mkString(",")}"))

def lookupTypeVar(name: TypeVarName): Option[s.SEVar] =
def lookupTypeVar(name: TypeVarName): Option[s.SExpr] =
lookupVar(TVarRef(name))

}
Expand Down Expand Up @@ -477,7 +469,7 @@ private[lf] final class Compiler(
case EVal(ref) =>
s.SEVal(t.LfDefRef(ref))
case EBuiltin(bf) =>
compileBuiltin(bf)
compileBuiltin(env, bf)
case EPrimCon(con) =>
compilePrimCon(con)
case EPrimLit(lit) =>
Expand Down Expand Up @@ -587,9 +579,24 @@ private[lf] final class Compiler(
}

@inline
private[this] def compileBuiltin(bf: BuiltinFunction): s.SExpr =
private[this] def compileIdentity(env: Env) = s.SEAbs(1, s.SEVarLevel(env.position))

@inline
private[this] def compileBuiltin(env: Env, bf: BuiltinFunction): s.SExpr = {

def SBCompareNumeric(b: SBuiltinPure) = {
val d = env.position
s.SEAbs(3, s.SEApp(s.SEBuiltin(b), List(s.SEVarLevel(d + 1), s.SEVarLevel(d + 2))))
}

val SBLessNumeric = SBCompareNumeric(SBLess)
val SBLessEqNumeric = SBCompareNumeric(SBLessEq)
val SBGreaterNumeric = SBCompareNumeric(SBGreater)
val SBGreaterEqNumeric = SBCompareNumeric(SBGreaterEq)
val SBEqualNumeric = SBCompareNumeric(SBEqual)

bf match {
case BCoerceContractId => s.SEAbs.identity
case BCoerceContractId => compileIdentity(env)
// Numeric Comparisons
case BLessNumeric => SBLessNumeric
case BLessEqNumeric => SBLessEqNumeric
Expand Down Expand Up @@ -712,6 +719,7 @@ private[lf] final class Compiler(
case BAnyExceptionMessage => SBAnyExceptionMessage
})
}
}

@inline
private[this] def compilePrimCon(con: PrimCon): s.SExpr =
Expand Down Expand Up @@ -1345,7 +1353,7 @@ private[lf] final class Compiler(
ifaceId: Identifier,
): (t.SDefinitionRef, SDefinition) =
t.ImplementsDefRef(tmplId, ifaceId) ->
SDefinition(unsafeClosureConvert(s.SEAbs.identity))
SDefinition(unsafeClosureConvert(compileIdentity(Env.Empty)))

// Compile the implementation of an interface method.
private[this] def compileImplementsMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ package speedy
*
* 1: convert from LF
* - reducing binding forms (update & scenario becoming builtins)
* - moving to de Bruijn indexing for variable
* - convert named variables to de Bruijn levels
* - moving to multi-argument applications and abstractions.
*
* 2: closure conversion
Expand Down Expand Up @@ -54,13 +54,11 @@ private[speedy] object SExpr0 {

sealed abstract class SExpr extends Product with Serializable

/** Reference to a variable. 'index' is the 1-based de Bruijn index,
* that is, SEVar(1) points to the nearest enclosing variable binder.
* which could be an SELam, SELet, or a binding variant of SECasePat.
/** Reference to a variable. 'level' is the 0-based de Bruijn LEVEL (not INDEX)
* https://en.wikipedia.org/wiki/De_Bruijn_index
* This expression form is only allowed prior to closure conversion
*/
final case class SEVar(index: Int) extends SExpr
final case class SEVarLevel(level: Int) extends SExpr

/** Reference to a value. On first lookup the evaluated expression is
* stored in 'cached'.
Expand All @@ -81,14 +79,6 @@ private[speedy] object SExpr0 {
/** Lambda abstraction. Transformed to SEMakeClo during closure conversion */
final case class SEAbs(arity: Int, body: SExpr) extends SExpr

object SEAbs {
// Helper for constructing abstraction expressions:
// SEAbs(1) { ... }
def apply(arity: Int)(body: SExpr): SExpr = SEAbs(arity, body)

val identity: SEAbs = SEAbs(1, SEVar(1))
}

/** Pattern match. */
final case class SECase(scrut: SExpr, alts: List[SCaseAlt]) extends SExpr

Expand Down