diff --git a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Compiler.scala b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Compiler.scala index edf84e9ef05e..6f57a311c061 100644 --- a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Compiler.scala +++ b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Compiler.scala @@ -167,45 +167,58 @@ private[lf] final class Compiler( case class Position(idx: Int) - private[this] def nextPosition(): Position = { - val p = env.position - env = env.copy(position = env.position + 1) - Position(p) + private[this] object Env { + val Empty = Env(0, Map.empty) } - private[this] def svar(p: Position): SEVar = SEVar(env.position - p.idx) + private[this] case class Env( + position: Int, + varIndices: Map[VarRef, Position], + ) { - private[this] def addVar(ref: VarRef, position: Position) = - env = env.copy(varIndices = env.varIndices.updated(ref, position)) + def toSEVar(p: Position): SEVar = SEVar(position - p.idx) - private[this] def addExprVar(name: ExprVarName, position: Position) = - addVar(EVarRef(name), position) + def nextPosition = Position(position) - private[this] def addTypeVar(name: TypeVarName, position: Position) = - addVar(TVarRef(name), position) + def pushVar: Env = copy(position = position + 1) - private[this] def hideTypeVar(name: TypeVarName) = - env = env.copy(varIndices = env.varIndices - TVarRef(name)) + private[this] def bindVar(ref: VarRef, p: Position) = + copy(varIndices = varIndices.updated(ref, p)) - private[this] def vars: List[VarRef] = env.varIndices.keys.toList + def pushVar(ref: VarRef): Env = + bindVar(ref, nextPosition).pushVar - private[this] def lookupVar(varRef: VarRef): Option[SEVar] = - env.varIndices.get(varRef).map(svar) + def pushExprVar(name: ExprVarName): Env = + pushVar(EVarRef(name)) - private[this] def lookupExprVar(name: ExprVarName): SEVar = - lookupVar(EVarRef(name)) - .getOrElse(throw CompilationError(s"Unknown variable: $name. Known: ${vars.mkString(",")}")) + def pushExprVar(maybeName: Option[ExprVarName]): Env = + maybeName match { + case Some(name) => pushExprVar(name) + case None => pushVar + } - private[this] def lookupTypeVar(name: TypeVarName): Option[SEVar] = - lookupVar(TVarRef(name)) + def pushTypeVar(name: ExprVarName): Env = + pushVar(TVarRef(name)) - private[this] case class Env( - position: Int = 0, - varIndices: Map[VarRef, Position] = Map.empty, - ) + def hideTypeVar(name: TypeVarName): Env = + copy(varIndices = varIndices - TVarRef(name)) + + def bindExprVar(name: ExprVarName, p: Position): Env = + bindVar(EVarRef(name), p) + + private[this] def vars: List[VarRef] = varIndices.keys.toList + + private[this] def lookupVar(varRef: VarRef): Option[SEVar] = + varIndices.get(varRef).map(toSEVar) + + def lookupExprVar(name: ExprVarName): SEVar = + lookupVar(EVarRef(name)) + .getOrElse(throw CompilationError(s"Unknown variable: $name. Known: ${vars.mkString(",")}")) + + def lookupTypeVar(name: TypeVarName): Option[SEVar] = + lookupVar(TVarRef(name)) - /** Environment mapping names into stack positions */ - private[this] var env = Env() + } private[this] val withLabel: (Profile.Label, SExpr) => SExpr = config.profiling match { @@ -231,39 +244,56 @@ private[lf] final class Compiler( private[this] def app(f: SExpr, a: SExpr) = SEApp(f, Array(a)) - private[this] def let(bound: SExpr)(f: Position => SExpr): SELet = - f(nextPosition()) match { + private[this] def let(env: Env, bound: SExpr)(f: (Position, Env) => SExpr): SELet = + f(env.nextPosition, env.pushVar) match { case SELet(bounds, body) => SELet(bound :: bounds, body) case otherwise => SELet(List(bound), otherwise) } - private[this] def unaryFunction(f: Position => SExpr): SEAbs = - withEnv { _ => - f(nextPosition()) - } match { + private[this] def unaryFunction(env: Env)(f: (Position, Env) => SExpr): SEAbs = + f(env.nextPosition, env.pushVar) match { case SEAbs(n, body) => SEAbs(n + 1, body) case otherwise => SEAbs(1, otherwise) } - private[this] def labeledUnaryFunction[L: Profile.LabelModule.Allowed](label: L with AnyRef)( - body: Position => SExpr + private[this] def labeledUnaryFunction[L: Profile.LabelModule.Allowed]( + label: L with AnyRef, + env: Env, + )( + body: (Position, Env) => SExpr ): SExpr = - unaryFunction(pos => withLabel(label, body(pos))) + unaryFunction(env)((positions, env) => withLabel(label, body(positions, env))) private[this] def topLevelFunction[SDefRef <: SDefinitionRef: LabelModule.Allowed]( - ref: SDefRef, - arity: Int, + ref: SDefRef )( - body: PartialFunction[List[Position], SExpr] + body: SExpr ): (SDefRef, SDefinition) = - ref -> - SDefinition( - unsafeClosureConvert( - withLabel(ref, SEAbs(arity, body(List.fill(arity)(nextPosition())))) - ) - ) + ref -> SDefinition(unsafeClosureConvert(withLabel(ref, body))) + + private val Position1 = Env.Empty.nextPosition + private val Env1 = Env.Empty.pushVar + private val Position2 = Env1.nextPosition + private val Env2 = Env1.pushVar + private val Position3 = Env2.nextPosition + private val Env3 = Env2.pushVar + + private[this] def topLevelFunction1[SDefRef <: SDefinitionRef: LabelModule.Allowed](ref: SDefRef)( + body: (Position, Env) => SExpr + ): (SDefRef, SDefinition) = + topLevelFunction(ref)(SEAbs(1, body(Position1, Env1))) + + private[this] def topLevelFunction2[SDefRef <: SDefinitionRef: LabelModule.Allowed](ref: SDefRef)( + body: (Position, Position, Env) => SExpr + ): (SDefRef, SDefinition) = + topLevelFunction(ref)(SEAbs(2, body(Position1, Position2, Env2))) + + private[this] def topLevelFunction3[SDefRef <: SDefinitionRef: LabelModule.Allowed](ref: SDefRef)( + body: (Position, Position, Position, Env) => SExpr + ): (SDefRef, SDefinition) = + topLevelFunction(ref)(SEAbs(3, body(Position1, Position2, Position3, Env3))) @throws[PackageNotFound] @throws[CompilationError] @@ -278,7 +308,7 @@ private[lf] final class Compiler( @throws[PackageNotFound] @throws[CompilationError] def unsafeCompile(expr: Expr): SExpr = - validate(compilationPipeline(compile(expr))) + validate(compilationPipeline(compile(Env.Empty, expr))) @throws[PackageNotFound] @throws[CompilationError] @@ -405,10 +435,10 @@ private[lf] final class Compiler( case SCPCons => 2 } - private[this] def compile(expr0: Expr): SExpr = + private[this] def compile(env: Env, expr0: Expr): SExpr = expr0 match { case EVar(name) => - lookupExprVar(name) + env.lookupExprVar(name) case EVal(ref) => SEVal(LfDefRef(ref)) case EBuiltin(bf) => @@ -418,13 +448,11 @@ private[lf] final class Compiler( case EPrimLit(lit) => compilePrimLit(lit) case EAbs(_, _, _) | ETyAbs(_, _) => - withEnv { _ => - compileAbss(expr0) - } + compileAbss(env, expr0) case EApp(_, _) | ETyApp(_, _) => - compileApps(expr0) + compileApps(env, expr0) case ERecCon(tApp, fields) => - compileERecCon(tApp, fields) + compileERecCon(env, tApp, fields) case ERecProj(tapp, field, record) => SBRecProj( tapp.tycon, @@ -432,38 +460,39 @@ private[lf] final class Compiler( NameOf.qualifiedNameOfCurrentFunc, interface.lookupRecordFieldInfo(tapp.tycon, field), ).index, - )( - compile(record) - ) + )(compile(env, record)) case erecupd: ERecUpd => - compileERecUpd(erecupd) + compileERecUpd(env, erecupd) case EStructCon(fields) => val fieldsInputOrder = Struct.assertFromSeq(fields.iterator.map(_._1).zipWithIndex.toSeq) SEApp( SEBuiltin(SBStructCon(fieldsInputOrder)), - mapToArray(fields) { case (_, e) => compile(e) }, + mapToArray(fields) { case (_, e) => compile(env, e) }, ) case structProj: EStructProj => structProj.fieldIndex match { - case None => SBStructProjByName(structProj.field)(compile(structProj.struct)) - case Some(index) => SBStructProj(index)(compile(structProj.struct)) + case None => SBStructProjByName(structProj.field)(compile(env, structProj.struct)) + case Some(index) => SBStructProj(index)(compile(env, structProj.struct)) } case structUpd: EStructUpd => structUpd.fieldIndex match { case None => - SBStructUpdByName(structUpd.field)(compile(structUpd.struct), compile(structUpd.update)) + SBStructUpdByName(structUpd.field)( + compile(env, structUpd.struct), + compile(env, structUpd.update), + ) case Some(index) => - SBStructUpd(index)(compile(structUpd.struct), compile(structUpd.update)) + SBStructUpd(index)(compile(env, structUpd.struct), compile(env, structUpd.update)) } case ECase(scrut, alts) => - compileECase(scrut, alts) + compileECase(env, scrut, alts) case ENil(_) => SEValue.EmptyList case ECons(_, front, tail) => // TODO(JM): Consider emitting SEValue(SList(...)) for // constant lists? - val args = (front.iterator.map(compile) ++ Seq(compile(tail))).toArray + val args = (front.iterator.map(compile(env, _)) ++ Seq(compile(env, tail))).toArray if (front.length == 1) { SEApp(SEBuiltin(SBCons), args) } else { @@ -472,7 +501,7 @@ private[lf] final class Compiler( case ENone(_) => SEValue.None case ESome(_, body) => - SBSome(compile(body)) + SBSome(compile(env, body)) case EEnumCon(tyCon, consName) => val rank = handleLookup( NameOf.qualifiedNameOfCurrentFunc, @@ -484,35 +513,35 @@ private[lf] final class Compiler( NameOf.qualifiedNameOfCurrentFunc, interface.lookupVariantConstructor(tapp.tycon, variant), ).rank - SBVariantCon(tapp.tycon, variant, rank)(compile(arg)) + SBVariantCon(tapp.tycon, variant, rank)(compile(env, arg)) case let: ELet => - withEnv(_ => compileELet(let)) + compileELet(env, let) case EUpdate(upd) => - compileEUpdate(upd) + compileEUpdate(env, upd) case ELocation(loc, EScenario(scen)) => - maybeSELocation(loc, compileScenario(scen, Some(loc))) + maybeSELocation(loc, compileScenario(env, scen, Some(loc))) case EScenario(scen) => - compileScenario(scen, None) + compileScenario(env, scen, None) case ELocation(loc, e) => - maybeSELocation(loc, compile(e)) + maybeSELocation(loc, compile(env, e)) case EToAny(ty, e) => - SBToAny(ty)(compile(e)) + SBToAny(ty)(compile(env, e)) case EFromAny(ty, e) => - SBFromAny(ty)(compile(e)) + SBFromAny(ty)(compile(env, e)) case ETypeRep(typ) => SEValue(STypeRep(typ)) case EToAnyException(ty, e) => - SBToAny(ty)(compile(e)) + SBToAny(ty)(compile(env, e)) case EFromAnyException(ty, e) => - SBFromAny(ty)(compile(e)) + SBFromAny(ty)(compile(env, e)) case EThrow(_, ty, e) => - SBThrow(SBToAny(ty)(compile(e))) + SBThrow(SBToAny(ty)(compile(env, e))) case EToInterface(iface @ _, tpl @ _, e) => - compile(e) // interfaces have the same representation as underlying template + compile(env, e) // interfaces have the same representation as underlying template case EFromInterface(iface @ _, tpl, e) => - SBFromInterface(tpl)(compile(e)) + SBFromInterface(tpl)(compile(env, e)) case ECallInterface(iface, methodName, e) => - SBCallInterface(iface, methodName)(compile(e)) + SBCallInterface(iface, methodName)(compile(env, e)) case EExperimental(name, _) => SBExperimental(name) @@ -684,18 +713,20 @@ private[lf] final class Compiler( private def noArgs = new util.ArrayList[SValue](0) @inline - private[this] def compileERecCon(tApp: TypeConApp, fields: ImmArray[(FieldName, Expr)]): SExpr = + private[this] def compileERecCon( + env: Env, + tApp: TypeConApp, + fields: ImmArray[(FieldName, Expr)], + ): SExpr = if (fields.isEmpty) SEValue(SRecord(tApp.tycon, ImmArray.Empty, noArgs)) - else { + else SEApp( SEBuiltin(SBRecCon(tApp.tycon, fields.map(_._1))), - fields.iterator.map(f => compile(f._2)).toArray, + fields.iterator.map(f => compile(env, f._2)).toArray, ) - } - @inline - private[this] def compileERecUpd(erecupd: ERecUpd): SExpr = { + private[this] def compileERecUpd(env: Env, erecupd: ERecUpd): SExpr = { val tapp = erecupd.tycon val (record, fields, updates) = collectRecUpds(erecupd) if (fields.length == 1) { @@ -703,7 +734,7 @@ private[lf] final class Compiler( NameOf.qualifiedNameOfCurrentFunc, interface.lookupRecordFieldInfo(tapp.tycon, fields.head), ).index - SBRecUpd(tapp.tycon, index)(compile(record), compile(updates.head)) + SBRecUpd(tapp.tycon, index)(compile(env, record), compile(env, updates.head)) } else { val indices = fields.map(name => @@ -712,14 +743,13 @@ private[lf] final class Compiler( interface.lookupRecordFieldInfo(tapp.tycon, name), ).index ) - SBRecUpdMulti(tapp.tycon, indices.to(ImmArray))((record :: updates).map(compile): _*) + SBRecUpdMulti(tapp.tycon, indices.to(ImmArray))((record :: updates).map(compile(env, _)): _*) } } - @inline - private[this] def compileECase(scrut: Expr, alts: ImmArray[CaseAlt]): SExpr = + private[this] def compileECase(env: Env, scrut: Expr, alts: ImmArray[CaseAlt]): SExpr = SECase( - compile(scrut), + compile(env, scrut), mapToArray(alts) { case CaseAlt(pat, expr) => pat match { case CPVariant(tycon, variant, binder) => @@ -727,38 +757,32 @@ private[lf] final class Compiler( NameOf.qualifiedNameOfCurrentFunc, interface.lookupVariantConstructor(tycon, variant), ).rank - withBinders(binder) { _ => - SCaseAlt(SCPVariant(tycon, variant, rank), compile(expr)) - } + SCaseAlt(SCPVariant(tycon, variant, rank), compile(env.pushExprVar(binder), expr)) case CPEnum(tycon, constructor) => val rank = handleLookup( NameOf.qualifiedNameOfCurrentFunc, interface.lookupEnumConstructor(tycon, constructor), ) - SCaseAlt(SCPEnum(tycon, constructor, rank), compile(expr)) + SCaseAlt(SCPEnum(tycon, constructor, rank), compile(env, expr)) case CPNil => - SCaseAlt(SCPNil, compile(expr)) + SCaseAlt(SCPNil, compile(env, expr)) case CPCons(head, tail) => - withBinders(head, tail) { _ => - SCaseAlt(SCPCons, compile(expr)) - } + SCaseAlt(SCPCons, compile(env.pushExprVar(head).pushExprVar(tail), expr)) case CPPrimCon(pc) => - SCaseAlt(SCPPrimCon(pc), compile(expr)) + SCaseAlt(SCPPrimCon(pc), compile(env, expr)) case CPNone => - SCaseAlt(SCPNone, compile(expr)) + SCaseAlt(SCPNone, compile(env, expr)) case CPSome(body) => - withBinders(body) { _ => - SCaseAlt(SCPSome, compile(expr)) - } + SCaseAlt(SCPSome, compile(env.pushExprVar(body), expr)) case CPDefault => - SCaseAlt(SCPDefault, compile(expr)) + SCaseAlt(SCPDefault, compile(env, expr)) } }, ) @@ -766,18 +790,18 @@ private[lf] final class Compiler( // Compile nested lets using constant stack. @tailrec private[this] def compileELet( + env0: Env, eLet0: ELet, bounds0: List[SExpr] = List.empty, ): SELet = { val binding = eLet0.binding - val bounds = withOptLabel(binding.binder, compile(binding.bound)) :: bounds0 - val boundPos = nextPosition() - binding.binder.foreach(addExprVar(_, boundPos)) + val bounds = withOptLabel(binding.binder, compile(env0, binding.bound)) :: bounds0 + val env1 = env0.pushExprVar(binding.binder) eLet0.body match { case eLet1: ELet => - compileELet(eLet1, bounds) + compileELet(env1, eLet1, bounds) case body0 => - compile(body0) match { + compile(env1, body0) match { case SELet(bounds1, body1) => SELet(bounds.foldLeft(bounds1)((acc, b) => b :: acc), body1) case otherwise => @@ -787,106 +811,106 @@ private[lf] final class Compiler( } @inline - private[this] def compileEUpdate(update: Update): SExpr = + private[this] def compileEUpdate(env: Env, update: Update): SExpr = update match { case UpdatePure(_, e) => - compilePure(e) + compilePure(env, e) case UpdateBlock(bindings, body) => - compileBlock(bindings, body) + compileBlock(env, bindings, body) case UpdateFetch(tmplId, coidE) => - FetchDefRef(tmplId)(compile(coidE)) + FetchDefRef(tmplId)(compile(env, coidE)) case UpdateFetchInterface(ifaceId, coidE) => - FetchDefRef(ifaceId)(compile(coidE)) + FetchDefRef(ifaceId)(compile(env, coidE)) case UpdateEmbedExpr(_, e) => - compileEmbedExpr(e) + compileEmbedExpr(env, e) case UpdateCreate(tmplId, arg) => - CreateDefRef(tmplId)(compile(arg)) + CreateDefRef(tmplId)(compile(env, arg)) case UpdateExercise(tmplId, chId, cidE, argE) => - ChoiceDefRef(tmplId, chId)(compile(cidE), compile(argE)) + ChoiceDefRef(tmplId, chId)(compile(env, cidE), compile(env, argE)) case UpdateExerciseInterface(ifaceId, chId, cidE, argE) => - ChoiceDefRef(ifaceId, chId)(compile(cidE), compile(argE)) + ChoiceDefRef(ifaceId, chId)(compile(env, cidE), compile(env, argE)) case UpdateExerciseByKey(tmplId, chId, keyE, argE) => - ChoiceByKeyDefRef(tmplId, chId)(compile(keyE), compile(argE)) + ChoiceByKeyDefRef(tmplId, chId)(compile(env, keyE), compile(env, argE)) case UpdateGetTime => SEGetTime case UpdateLookupByKey(RetrieveByKey(templateId, key)) => - LookupByKeyDefRef(templateId)(compile(key)) + LookupByKeyDefRef(templateId)(compile(env, key)) case UpdateFetchByKey(RetrieveByKey(templateId, key)) => - FetchByKeyDefRef(templateId)(compile(key)) + FetchByKeyDefRef(templateId)(compile(env, key)) case UpdateTryCatch(_, body, binder, handler) => - unaryFunction { tokenPos => + unaryFunction(env) { (tokenPos, env0) => SETryCatch( - app(compile(body), svar(tokenPos)), - withEnv { _ => - val binderPos = nextPosition() - addExprVar(binder, binderPos) - SBTryHandler(compile(handler), svar(binderPos), svar(tokenPos)) + app(compile(env0, body), env0.toSEVar(tokenPos)), { + val env1 = env0.pushExprVar(binder) + SBTryHandler( + compile(env1, handler), + env1.lookupExprVar(binder), + env1.toSEVar(tokenPos), + ) }, ) } } @tailrec - private[this] def compileAbss(expr0: Expr, arity: Int = 0): SExpr = + private[this] def compileAbss(env: Env, expr0: Expr, arity: Int = 0): SExpr = expr0 match { case EAbs((binder, typ @ _), body, ref @ _) => - addExprVar(binder, nextPosition()) - compileAbss(body, arity + 1) + compileAbss(env.pushExprVar(binder), body, arity + 1) case ETyAbs((binder, KNat), body) => - addTypeVar(binder, nextPosition()) - compileAbss(body, arity + 1) + compileAbss(env.pushTypeVar(binder), body, arity + 1) case ETyAbs((binder, _), body) => - hideTypeVar(binder) - compileAbss(body, arity) + compileAbss(env.hideTypeVar(binder), body, arity) case _ if arity == 0 => - compile(expr0) + compile(env, expr0) case _ => - withLabel(AnonymousClosure, SEAbs(arity, compile(expr0))) + withLabel(AnonymousClosure, SEAbs(arity, compile(env, expr0))) } @tailrec - private[this] def compileApps(expr0: Expr, args: List[SExpr] = List.empty): SExpr = + private[this] def compileApps(env: Env, expr0: Expr, args: List[SExpr] = List.empty): SExpr = expr0 match { case EApp(fun, arg) => - compileApps(fun, compile(arg) :: args) + compileApps(env, fun, compile(env, arg) :: args) case ETyApp(fun, arg) => - compileApps(fun, translateType(arg).fold(args)(_ :: args)) + compileApps(env, fun, translateType(env, arg).fold(args)(_ :: args)) case _ if args.isEmpty => - compile(expr0) + compile(env, expr0) case _ => - SEApp(compile(expr0), args.toArray) + SEApp(compile(env, expr0), args.toArray) } - private[this] def translateType(typ: Type): Option[SExpr] = + private[this] def translateType(env: Env, typ: Type): Option[SExpr] = typ match { case TNat(n) => SENat(n) - case TVar(name) => lookupTypeVar(name) + case TVar(name) => env.lookupTypeVar(name) case _ => None } - private[this] def compileScenario(scen: Scenario, optLoc: Option[Location]): SExpr = + private[this] def compileScenario(env: Env, scen: Scenario, optLoc: Option[Location]): SExpr = scen match { case ScenarioPure(_, e) => - compilePure(e) + compilePure(env, e) case ScenarioBlock(bindings, body) => - compileBlock(bindings, body) + compileBlock(env, bindings, body) case ScenarioCommit(partyE, updateE, _retType @ _) => - compileCommit(partyE, updateE, optLoc, mustFail = false) + compileCommit(env, partyE, updateE, optLoc, mustFail = false) case ScenarioMustFailAt(partyE, updateE, _retType @ _) => - compileCommit(partyE, updateE, optLoc, mustFail = true) + compileCommit(env, partyE, updateE, optLoc, mustFail = true) case ScenarioGetTime => SEGetTime case ScenarioGetParty(e) => - compileGetParty(e) + compileGetParty(env, e) case ScenarioPass(relTime) => - compilePass(relTime) + compilePass(env, relTime) case ScenarioEmbedExpr(_, e) => - compileEmbedExpr(e) + compileEmbedExpr(env, e) } @inline private[this] def compileCommit( + env: Env, partyE: Expr, updateE: Expr, optLoc: Option[Location], @@ -895,49 +919,45 @@ private[lf] final class Compiler( // let party = // update = // in $submit(mustFail)(party, update) - withEnv { _ => - let(compile(partyE)) { partyLoc => - let(compile(updateE)) { updateLoc => - SBSSubmit(optLoc, mustFail)(svar(partyLoc), svar(updateLoc)) - } + let(env, compile(env, partyE)) { (partyLoc, env) => + let(env, compile(env, updateE)) { (updateLoc, env) => + SBSSubmit(optLoc, mustFail)(env.toSEVar(partyLoc), env.toSEVar(updateLoc)) } } @inline - private[this] def compileGetParty(expr: Expr): SExpr = - labeledUnaryFunction(Profile.GetPartyLabel) { tokenPos => - SBSGetParty(compile(expr), svar(tokenPos)) + private[this] def compileGetParty(env: Env, expr: Expr): SExpr = + labeledUnaryFunction(Profile.GetPartyLabel, env) { (tokenPos, env) => + SBSGetParty(compile(env, expr), env.toSEVar(tokenPos)) } @inline - private[this] def compilePass(time: Expr): SExpr = - labeledUnaryFunction(Profile.PassLabel) { tokenPos => - SBSPass(compile(time), svar(tokenPos)) + private[this] def compilePass(env: Env, time: Expr): SExpr = + labeledUnaryFunction(Profile.PassLabel, env) { (tokenPos, env) => + SBSPass(compile(env, time), env.toSEVar(tokenPos)) } @inline - private[this] def compileEmbedExpr(expr: Expr): SExpr = + private[this] def compileEmbedExpr(env: Env, expr: Expr): SExpr = // EmbedExpr's get wrapped into an extra layer of abstraction // to delay evaluation. // e.g. // embed (error "foo") => \token -> error "foo" - unaryFunction { tokenPos => - app(compile(expr), svar(tokenPos)) + unaryFunction(env) { (tokenPos, env) => + app(compile(env, expr), env.toSEVar(tokenPos)) } - private[this] def compilePure(body: Expr): SExpr = + private[this] def compilePure(env: Env, body: Expr): SExpr = // pure // => // ((\x token -> x) ) - withEnv { _ => - let(compile(body)) { bodyPos => - unaryFunction { tokenPos => - SBSPure(svar(bodyPos), svar(tokenPos)) - } + let(env, compile(env, body)) { (bodyPos, env) => + unaryFunction(env) { (tokenPos, env) => + SBSPure(env.toSEVar(bodyPos), env.toSEVar(tokenPos)) } } - private[this] def compileBlock(bindings: ImmArray[Binding], body: Expr): SExpr = + private[this] def compileBlock(env: Env, bindings: ImmArray[Binding], body: Expr): SExpr = // do // x <- f // y <- g x @@ -948,24 +968,22 @@ private[lf] final class Compiler( // let x = f' token // y = g x token // in z x y token - withEnv { _ => - let(compile(bindings.head.bound)) { firstPos => - unaryFunction { tokenPos => - let(app(svar(firstPos), svar(tokenPos))) { firstBoundPos => - bindings.head.binder.foreach(addExprVar(_, firstBoundPos)) - - def loop(list: List[Binding]): SExpr = list match { - case Binding(binder, _, bound) :: tail => - let(app(compile(bound), svar(tokenPos))) { boundPos => - binder.foreach(addExprVar(_, boundPos)) - loop(tail) - } - case Nil => - app(compile(body), svar(tokenPos)) - } - - loop(bindings.tail.toList) + let(env, compile(env, bindings.head.bound)) { (firstPos, env) => + unaryFunction(env) { (tokenPos, env) => + let(env, app(env.toSEVar(firstPos), env.toSEVar(tokenPos))) { (firstBoundPos, _env) => + val env = bindings.head.binder.fold(_env)(_env.bindExprVar(_, firstBoundPos)) + + def loop(env: Env, list: List[Binding]): SExpr = list match { + case Binding(binder, _, bound) :: tail => + let(env, app(compile(env, bound), env.toSEVar(tokenPos))) { (boundPos, _env) => + val env = binder.fold(_env)(_env.bindExprVar(_, boundPos)) + loop(env, tail) + } + case Nil => + app(compile(env, body), env.toSEVar(tokenPos)) } + + loop(env, bindings.tail.toList) } } } @@ -973,19 +991,27 @@ private[lf] final class Compiler( private[this] val KeyWithMaintainersStruct = SBStructCon(Struct.assertFromSeq(List(keyFieldName, maintainersFieldName).zipWithIndex)) - private[this] def encodeKeyWithMaintainers(keyPos: Position, tmplKey: TemplateKey): SExpr = - KeyWithMaintainersStruct(svar(keyPos), app(compile(tmplKey.maintainers), svar(keyPos))) + private[this] def encodeKeyWithMaintainers( + env: Env, + keyPos: Position, + tmplKey: TemplateKey, + ): SExpr = + KeyWithMaintainersStruct( + env.toSEVar(keyPos), + app(compile(env, tmplKey.maintainers), env.toSEVar(keyPos)), + ) - private[this] def compileKeyWithMaintainers(maybeTmplKey: Option[TemplateKey]): SExpr = - withEnv { _ => - maybeTmplKey match { - case None => SEValue.None - case Some(tmplKey) => - let(compile(tmplKey.body))(keyPos => SBSome(encodeKeyWithMaintainers(keyPos, tmplKey))) - } + private[this] def compileKeyWithMaintainers(env: Env, maybeTmplKey: Option[TemplateKey]): SExpr = + maybeTmplKey match { + case None => SEValue.None + case Some(tmplKey) => + let(env, compile(env, tmplKey.body)) { (keyPos, env) => + SBSome(encodeKeyWithMaintainers(env, keyPos, tmplKey)) + } } private[this] def compileChoiceBody( + env: Env, tmplId: TypeConName, tmpl: Template, choice: TemplateChoice, @@ -994,34 +1020,38 @@ private[lf] final class Compiler( cidPos: Position, mbKey: Option[Position], // defined for byKey operation tokenPos: Position, - ) = withEnv { _ => + ) = let( + env, SBUFetch( tmplId - )(svar(cidPos), mbKey.fold(SEValue.None: SExpr)(pos => SBSome(svar(pos)))) - ) { tmplArgPos => - addExprVar(tmpl.param, tmplArgPos) - addExprVar(choice.argBinder._1, choiceArgPos) + )(env.toSEVar(cidPos), mbKey.fold(SEValue.None: SExpr)(pos => SBSome(env.toSEVar(pos)))), + ) { (tmplArgPos, _env) => + val env = + _env.bindExprVar(tmpl.param, tmplArgPos).bindExprVar(choice.argBinder._1, choiceArgPos) let( + env, SBUBeginExercise(tmplId, choice.name, choice.consuming, byKey = mbKey.isDefined)( - svar(choiceArgPos), - svar(cidPos), - compile(choice.controllers), + env.toSEVar(choiceArgPos), + env.toSEVar(cidPos), + compile(env, choice.controllers), choice.choiceObservers match { - case Some(observers) => compile(observers) + case Some(observers) => compile(env, observers) case None => SEValue.EmptyList }, + ), + ) { (_, _env) => + val env = _env.bindExprVar(choice.selfBinder, cidPos) + SEScopeExercise( + app(compile(env, choice.update), env.toSEVar(tokenPos)) ) - ) { _ => - addExprVar(choice.selfBinder, cidPos) - SEScopeExercise(app(compile(choice.update), svar(tokenPos))) } } - } // TODO https://github.com/digital-asset/daml/issues/10810: // Try to factorise this with compileChoiceBody above. private[this] def compileFixedChoiceBody( + env: Env, ifaceId: TypeConName, param: ExprVarName, choice: TemplateChoice, @@ -1029,27 +1059,26 @@ private[lf] final class Compiler( choiceArgPos: Position, cidPos: Position, tokenPos: Position, - ) = withEnv { _ => - let(SBUFetchInterface(ifaceId)(svar(cidPos))) { payloadPos => - addExprVar(param, payloadPos) - addExprVar(choice.argBinder._1, choiceArgPos) + ) = + let(env, SBUFetchInterface(ifaceId)(env.toSEVar(cidPos))) { (payloadPos, _env) => + val env = _env.bindExprVar(param, payloadPos).bindExprVar(choice.argBinder._1, choiceArgPos) let( + env, SBResolveSBUBeginExercise(choice.name, choice.consuming, byKey = false)( - svar(payloadPos), - svar(choiceArgPos), - svar(cidPos), - compile(choice.controllers), + env.toSEVar(payloadPos), + env.toSEVar(choiceArgPos), + env.toSEVar(cidPos), + compile(env, choice.controllers), choice.choiceObservers match { - case Some(observers) => compile(observers) + case Some(observers) => compile(env, observers) case None => SEValue.EmptyList }, - ) - ) { _ => - addExprVar(choice.selfBinder, cidPos) - SEScopeExercise(app(compile(choice.update), svar(tokenPos))) + ), + ) { (_, _env) => + val env = _env.bindExprVar(choice.selfBinder, cidPos) + SEScopeExercise(app(compile(env, choice.update), env.toSEVar(tokenPos))) } } - } // TODO https://github.com/digital-asset/daml/issues/10810: // Here we fetch twice, once by interface Id once by template Id. Try to bypass the second fetch. @@ -1057,16 +1086,15 @@ private[lf] final class Compiler( ifaceId: TypeConName, choice: InterfaceChoice, ): (SDefinitionRef, SDefinition) = - topLevelFunction(ChoiceDefRef(ifaceId, choice.name), 3) { - case List(cidPos, choiceArgPos, tokenPos) => - let(SBUFetchInterface(ifaceId)(svar(cidPos))) { payloadPos => - SBResolveVirtualChoice(choice.name)( - svar(payloadPos), - svar(cidPos), - svar(choiceArgPos), - svar(tokenPos), - ) - } + topLevelFunction3(ChoiceDefRef(ifaceId, choice.name)) { (cidPos, choiceArgPos, tokenPos, env) => + let(env, SBUFetchInterface(ifaceId)(env.toSEVar(cidPos))) { (payloadPos, env) => + SBResolveVirtualChoice(choice.name)( + env.toSEVar(payloadPos), + env.toSEVar(cidPos), + env.toSEVar(choiceArgPos), + env.toSEVar(tokenPos), + ) + } } private[this] def compileFixedChoice( @@ -1074,13 +1102,12 @@ private[lf] final class Compiler( param: ExprVarName, choice: TemplateChoice, ): (SDefinitionRef, SDefinition) = - topLevelFunction(ChoiceDefRef(ifaceId, choice.name), 3) { - case List(cidPos, choiceArgPos, tokenPos) => - compileFixedChoiceBody(ifaceId, param, choice)( - choiceArgPos, - cidPos, - tokenPos, - ) + topLevelFunction3(ChoiceDefRef(ifaceId, choice.name)) { (cidPos, choiceArgPos, tokenPos, env) => + compileFixedChoiceBody(env, ifaceId, param, choice)( + choiceArgPos, + cidPos, + tokenPos, + ) } private[this] def compileChoice( @@ -1095,14 +1122,13 @@ private[lf] final class Compiler( // = [update] // _ = $endExercise[tmplId] // in - topLevelFunction(ChoiceDefRef(tmplId, choice.name), 3) { - case List(cidPos, choiceArgPos, tokenPos) => - compileChoiceBody(tmplId, tmpl, choice)( - choiceArgPos, - cidPos, - None, - tokenPos, - ) + topLevelFunction3(ChoiceDefRef(tmplId, choice.name)) { (cidPos, choiceArgPos, tokenPos, env) => + compileChoiceBody(env, tmplId, tmpl, choice)( + choiceArgPos, + cidPos, + None, + tokenPos, + ) } /** Compile a choice into a top-level function for exercising that choice */ @@ -1121,11 +1147,11 @@ private[lf] final class Compiler( // = // _ = $endExercise[tmplId] // in - topLevelFunction(ChoiceByKeyDefRef(tmplId, choice.name), 3) { - case List(keyPos, choiceArgPos, tokenPos) => - let(encodeKeyWithMaintainers(keyPos, tmplKey)) { keyWithMPos => - let(SBUFetchKey(tmplId)(svar(keyWithMPos))) { cidPos => - compileChoiceBody(tmplId, tmpl, choice)( + topLevelFunction3(ChoiceByKeyDefRef(tmplId, choice.name)) { + (keyPos, choiceArgPos, tokenPos, env) => + let(env, encodeKeyWithMaintainers(env, keyPos, tmplKey)) { (keyWithMPos, env) => + let(env, SBUFetchKey(tmplId)(env.toSEVar(keyWithMPos))) { (cidPos, env) => + compileChoiceBody(env, tmplId, tmpl, choice)( choiceArgPos, cidPos, Some(keyWithMPos), @@ -1142,19 +1168,6 @@ private[lf] final class Compiler( case _ => expr } - private[this] def withEnv[A](f: Unit => A): A = { - val oldEnv = env - val x = f(()) - env = oldEnv - x - } - - private[this] def withBinders[A](binders: ExprVarName*)(f: Unit => A): A = - withEnv { _ => - binders.foreach(addExprVar(_, nextPosition())) - f(()) - } - /** Convert abstractions in a speedy expression into * explicit closure creations. * This step computes the free variables in an abstraction @@ -1439,25 +1452,23 @@ private[lf] final class Compiler( } @nowarn("msg=parameter value tokenPos in method compileFetchBody is never used") - private[this] def compileFetchBody(tmplId: Identifier, tmpl: Template)( + private[this] def compileFetchBody(env: Env, tmplId: Identifier, tmpl: Template)( cidPos: Position, mbKey: Option[Position], //defined for byKey operation tokenPos: Position, ) = - withEnv { _ => + let( + env, + SBUFetch( + tmplId + )(env.toSEVar(cidPos), mbKey.fold(SEValue.None: SExpr)(pos => SBSome(env.toSEVar(pos)))), + ) { (tmplArgPos, _env) => + val env = _env.bindExprVar(tmpl.param, tmplArgPos) let( - SBUFetch( - tmplId - )(svar(cidPos), mbKey.fold(SEValue.None: SExpr)(pos => SBSome(svar(pos)))) - ) { tmplArgPos => - addExprVar(tmpl.param, tmplArgPos) - let( - SBUInsertFetchNode(tmplId, byKey = mbKey.isDefined)( - svar(cidPos) - ) - ) { _ => - svar(tmplArgPos) - } + env, + SBUInsertFetchNode(tmplId, byKey = mbKey.isDefined)(env.toSEVar(cidPos)), + ) { (_, env) => + env.toSEVar(tmplArgPos) } } @@ -1470,8 +1481,8 @@ private[lf] final class Compiler( // let = $fetch(tmplId) // _ = $insertFetch(tmplId, false) coid [tmpl.signatories] [tmpl.observers] [tmpl.key] // in - topLevelFunction(FetchDefRef(tmplId), 2) { case List(cidPos, tokenPos) => - compileFetchBody(tmplId, tmpl)(cidPos, None, tokenPos) + topLevelFunction2(FetchDefRef(tmplId)) { (cidPos, tokenPos, env) => + compileFetchBody(env, tmplId, tmpl)(cidPos, None, tokenPos) } // TODO https://github.com/digital-asset/daml/issues/10810: @@ -1479,9 +1490,13 @@ private[lf] final class Compiler( private[this] def compileFetchInterface( ifaceId: Identifier ): (SDefinitionRef, SDefinition) = - topLevelFunction(FetchDefRef(ifaceId), 2) { case List(cidPos, tokenPos) => - let(SBUFetchInterface(ifaceId)(svar(cidPos))) { payloadPos => - SBResolveVirtualFetch(svar(payloadPos), svar(cidPos), svar(tokenPos)) + topLevelFunction2(FetchDefRef(ifaceId)) { (cidPos, tokenPos, env) => + let(env, SBUFetchInterface(ifaceId)(env.toSEVar(cidPos))) { (payloadPos, env) => + SBResolveVirtualFetch( + env.toSEVar(payloadPos), + env.toSEVar(cidPos), + env.toSEVar(tokenPos), + ) } } @@ -1489,27 +1504,24 @@ private[lf] final class Compiler( tmplId: Identifier, tmpl: Template, ): (SDefinitionRef, SDefinition) = - topLevelFunction(KeyDefRef(tmplId), 1) { case List(tmplArgPos) => - addExprVar(tmpl.param, tmplArgPos) - compileKeyWithMaintainers(tmpl.key) + topLevelFunction1(KeyDefRef(tmplId)) { (tmplArgPos, env) => + compileKeyWithMaintainers(env.bindExprVar(tmpl.param, tmplArgPos), tmpl.key) } private[this] def compileSignatories( tmplId: Identifier, tmpl: Template, ): (SDefinitionRef, SDefinition) = - topLevelFunction(SignatoriesDefRef(tmplId), 1) { case List(tmplArgPos) => - addExprVar(tmpl.param, tmplArgPos) - compile(tmpl.signatories) + topLevelFunction1(SignatoriesDefRef(tmplId)) { (tmplArgPos, env) => + compile(env.bindExprVar(tmpl.param, tmplArgPos), tmpl.signatories) } private[this] def compileObservers( tmplId: Identifier, tmpl: Template, ): (SDefinitionRef, SDefinition) = - topLevelFunction(ObserversDefRef(tmplId), 1) { case List(tmplArgPos) => - addExprVar(tmpl.param, tmplArgPos) - compile(tmpl.observers) + topLevelFunction1(ObserversDefRef(tmplId)) { (tmplArgPos, env) => + compile(env.bindExprVar(tmpl.param, tmplArgPos), tmpl.observers) } // Turn a template value into an interface value. Since interfaces have a @@ -1520,9 +1532,8 @@ private[lf] final class Compiler( tmplId: Identifier, ifaceId: Identifier, ): (SDefinitionRef, SDefinition) = - topLevelFunction(ImplementsDefRef(tmplId, ifaceId), 1) { case List(tmplPos) => - svar(tmplPos) - } + ImplementsDefRef(tmplId, ifaceId) -> + SDefinition(flattenToAnf(unsafeClosureConvert(SEAbs.identity))) // Compile the implementation of an interface method. private[this] def compileImplementsMethod( @@ -1542,19 +1553,20 @@ private[lf] final class Compiler( // CreateDefRef(tmplId) = \ -> // let _ = $checkPreconf(tmplId)( [tmpl.precond] // in $create [tmpl.agreementText] [tmpl.signatories] [tmpl.observers] [tmpl.key] - topLevelFunction(CreateDefRef(tmplId), 2) { case List(tmplArgPos, tokenPos @ _) => - addExprVar(tmpl.param, tmplArgPos) + topLevelFunction2(CreateDefRef(tmplId)) { (tmplArgPos, _, _env) => + val env = _env.bindExprVar(tmpl.param, tmplArgPos) // We check precondition in a separated builtin to prevent // further evaluation of agreement, signatories, observers and key // in case of failed precondition. - let(SBCheckPrecond(tmplId)(svar(tmplArgPos), compile(tmpl.precond))) { _ => - SBUCreate(tmplId)( - svar(tmplArgPos), - compile(tmpl.agreementText), - compile(tmpl.signatories), - compile(tmpl.observers), - compileKeyWithMaintainers(tmpl.key), - ) + let(env, SBCheckPrecond(tmplId)(env.toSEVar(tmplArgPos), compile(env, tmpl.precond))) { + (_, env) => + SBUCreate(tmplId)( + env.toSEVar(tmplArgPos), + compile(env, tmpl.agreementText), + compile(env, tmpl.signatories), + compile(env, tmpl.observers), + compileKeyWithMaintainers(env, tmpl.key), + ) } } @@ -1564,10 +1576,15 @@ private[lf] final class Compiler( choiceId: ChoiceName, choiceArg: SValue, ): SExpr = - labeledUnaryFunction(Profile.CreateAndExerciseLabel(tmplId, choiceId)) { tokenPos => - let(CreateDefRef(tmplId)(SEValue(createArg), svar(tokenPos))) { cidPos => - ChoiceDefRef(tmplId, choiceId)(svar(cidPos), SEValue(choiceArg), svar(tokenPos)) - } + labeledUnaryFunction(Profile.CreateAndExerciseLabel(tmplId, choiceId), Env.Empty) { + (tokenPos, env) => + let(env, CreateDefRef(tmplId)(SEValue(createArg), env.toSEVar(tokenPos))) { (cidPos, env) => + ChoiceDefRef(tmplId, choiceId)( + env.toSEVar(cidPos), + SEValue(choiceArg), + env.toSEVar(tokenPos), + ) + } } private[this] def compileLookupByKey( @@ -1580,11 +1597,14 @@ private[lf] final class Compiler( // = $lookupKey(tmplId) // _ = $insertLookup(tmplId> // in - topLevelFunction(LookupByKeyDefRef(tmplId), 2) { case List(keyPos, tokenPos @ _) => - let(encodeKeyWithMaintainers(keyPos, tmplKey)) { keyWithMPos => - let(SBULookupKey(tmplId)(svar(keyWithMPos))) { maybeCidPos => - let(SBUInsertLookupNode(tmplId)(svar(keyWithMPos), svar(maybeCidPos))) { _ => - svar(maybeCidPos) + topLevelFunction2(LookupByKeyDefRef(tmplId)) { (keyPos, _, env) => + let(env, encodeKeyWithMaintainers(env, keyPos, tmplKey)) { (keyWithMPos, env) => + let(env, SBULookupKey(tmplId)(env.toSEVar(keyWithMPos))) { (maybeCidPos, env) => + let( + env, + SBUInsertLookupNode(tmplId)(env.toSEVar(keyWithMPos), env.toSEVar(maybeCidPos)), + ) { (_, env) => + env.toSEVar(maybeCidPos) } } } @@ -1606,11 +1626,12 @@ private[lf] final class Compiler( // = $fetch(tmplId) // _ = $insertFetch (Some ) // in { contractId: ContractId Foo, contract: Foo } - topLevelFunction(FetchByKeyDefRef(tmplId), 2) { case List(keyPos, tokenPos) => - let(encodeKeyWithMaintainers(keyPos, tmplKey)) { keyWithMPos => - let(SBUFetchKey(tmplId)(svar(keyWithMPos))) { cidPos => - let(compileFetchBody(tmplId, tmpl)(cidPos, Some(keyWithMPos), tokenPos)) { contractPos => - FetchByKeyResult(svar(cidPos), svar(contractPos)) + topLevelFunction2(FetchByKeyDefRef(tmplId)) { (keyPos, tokenPos, env) => + let(env, encodeKeyWithMaintainers(env, keyPos, tmplKey)) { (keyWithMPos, env) => + let(env, SBUFetchKey(tmplId)(env.toSEVar(keyWithMPos))) { (cidPos, env) => + let(env, compileFetchBody(env, tmplId, tmpl)(cidPos, Some(keyWithMPos), tokenPos)) { + (contractPos, env) => + FetchByKeyResult(env.toSEVar(cidPos), env.toSEVar(contractPos)) } } } @@ -1638,25 +1659,23 @@ private[lf] final class Compiler( LookupByKeyDefRef(templateId)(SEValue(contractKey)) } - private val SEUpdatePureUnit = unaryFunction(_ => SEValue.Unit) + private val SEUpdatePureUnit = unaryFunction(Env.Empty)((_, _) => SEValue.Unit) private[this] val handleEverything: SExpr = SBSome(SEUpdatePureUnit) - private[this] def catchEverything(e: SExpr): SExpr = { - unaryFunction { tokenPos => + private[this] def catchEverything(e: SExpr): SExpr = + unaryFunction(Env.Empty) { (tokenPos, env0) => SETryCatch( - app(e, svar(tokenPos)), - withEnv { _ => - val binderPos = nextPosition() - SBTryHandler(handleEverything, svar(binderPos), svar(tokenPos)) + app(e, env0.toSEVar(tokenPos)), { + val binderPos = env0.nextPosition + val env1 = env0.pushVar + SBTryHandler(handleEverything, env1.toSEVar(binderPos), env1.toSEVar(tokenPos)) }, ) } - } - private[this] def compileCommandForReinterpretation(cmd: Command): SExpr = { + private[this] def compileCommandForReinterpretation(cmd: Command): SExpr = catchEverything(compileCommand(cmd)) - } private[this] def compileCommands(bindings: ImmArray[Command]): SExpr = // commands are compile similarly as update block @@ -1665,15 +1684,16 @@ private[lf] final class Compiler( case Nil => SEUpdatePureUnit case first :: rest => - let(compileCommand(first)) { firstPos => - unaryFunction { tokenPos => - let(app(svar(firstPos), svar(tokenPos))) { _ => + let(Env.Empty, compileCommand(first)) { (firstPos, env) => + unaryFunction(env) { (tokenPos, env) => + let(env, app(env.toSEVar(firstPos), env.toSEVar(tokenPos))) { (_, _env) => // we cannot process `rest` recursively without exposing ourselves to stack overflow. - val exprs = rest.iterator.map { cmd => - val expr = app(compileCommand(cmd), svar(tokenPos)) - discard(nextPosition()) + var env = _env + val exprs = rest.map { cmd => + val expr = app(compileCommand(cmd), env.toSEVar(tokenPos)) + env = env.pushVar expr - }.toList + } SELet(exprs, SEValue.Unit) } }