Skip to content

Commit

Permalink
Suport multi-party readAs in triggers (#11299)
Browse files Browse the repository at this point in the history
* Suport multi-party readAs in triggers

fixes #7640

This does not yet include the trigger service. We’ll tackle that separately.

changelog_begin

- [Daml Triggers] Triggers now support readAs parties. They can be
  specified via `--ledger-readas a,b,c`. As part of this change
  ``testRule`` gained an extra argument to specify the `readAs`
  parties. If you previously used

  ```
  testRule trigger party acsBuilder commandsInFlight s
  ```

  you now need to use

  ```
  testRule trigger party [] acsBuilder commandsInFlight s
  ```

changelog_end

* Update triggers/tests/src/test/scala/com/digitalasset/daml/lf/engine/trigger/test/AbstractFuncTests.scala

Co-authored-by: Andreas Herrmann <[email protected]>

Co-authored-by: Andreas Herrmann <[email protected]>
  • Loading branch information
cocreature and aherrmann-da authored Oct 20, 2021
1 parent 76eb165 commit 07ad3e0
Show file tree
Hide file tree
Showing 21 changed files with 250 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ object Util {
object TFun extends ((Type, Type) => Type) {
def apply(targ: Type, tres: Type) =
TApp(TApp(TBuiltin(BTArrow), targ), tres)
def unapply(typ: Type): Option[(Type, Type)] = typ match {
case TApp(TApp(TBuiltin(BTArrow), targ), tres) => Some((targ, tres))
case _ => None
}
}

class ParametricType1(bType: BuiltinType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ test = script do

let acs = toACS cert <> toACS upgradeAgreement

(_, commands) <- testRule upgradeTrigger alice acs Map.empty ()
(_, commands) <- testRule upgradeTrigger alice [] acs Map.empty ()
let flatCommands = flattenCommands commands
assertExerciseCmd flatCommands $ \(cid, choiceArg) -> do
cid === upgradeAgreement
Expand Down
36 changes: 22 additions & 14 deletions triggers/daml/Daml/Trigger.daml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ module Daml.Trigger
, RegisteredTemplates(..)
, registeredTemplate
, RelTime(..)
, getReadAs
) where

import Prelude hiding (any)
Expand Down Expand Up @@ -116,20 +117,27 @@ class ActionTriggerAny m where
-- `emitCommands`.
queryPendingContracts : m [AnyContractId]

getReadAs : m [Party]

instance ActionTriggerAny (TriggerA s) where
implQuery = TriggerA $ pure . getContracts
queryContractId id = TriggerA $ pure . getContractById id
queryPendingContracts = TriggerA $ \acs -> pure (getPendingContracts acs)
getReadAs = TriggerA $ \_ -> do
s <- get
pure s.readAs

instance ActionTriggerAny (TriggerUpdateA s) where
implQuery = TriggerUpdateA $ pure . getContracts . snd
queryContractId id = TriggerUpdateA $ pure . getContractById id . snd
queryPendingContracts = TriggerUpdateA $ \(_, acs) -> pure (getPendingContracts acs)
implQuery = TriggerUpdateA $ \(_, acs, _) -> pure (getContracts acs)
queryContractId id = TriggerUpdateA $ \(_, acs, _) -> pure (getContractById id acs)
queryPendingContracts = TriggerUpdateA $ \(_, acs, _) -> pure (getPendingContracts acs)
getReadAs = TriggerUpdateA $ \(_, _, readAs) -> pure readAs

instance ActionTriggerAny TriggerInitializeA where
implQuery = TriggerInitializeA getContracts
queryContractId = TriggerInitializeA . getContractById
queryPendingContracts = TriggerInitializeA getPendingContracts
implQuery = TriggerInitializeA (\(acs, _) -> getContracts acs)
queryContractId id = TriggerInitializeA (\(acs, _) -> getContractById id acs)
queryPendingContracts = TriggerInitializeA (\(acs, _) -> getPendingContracts acs)
getReadAs = TriggerInitializeA (\(_, readAs) -> readAs)

-- | Features possible in `updateState` and `rule`.
class ActionTriggerAny m => ActionTriggerUpdate m where
Expand All @@ -140,7 +148,7 @@ class ActionTriggerAny m => ActionTriggerUpdate m where
getCommandsInFlight : m (Map CommandId [Command])

instance ActionTriggerUpdate (TriggerUpdateA s) where
getCommandsInFlight = TriggerUpdateA $ \(cif, _) -> pure cif
getCommandsInFlight = TriggerUpdateA $ \(cif, _, _) -> pure cif

instance ActionTriggerUpdate (TriggerA s) where
getCommandsInFlight = liftTriggerRule $ get <&> \s -> s.commandsInFlight
Expand Down Expand Up @@ -266,20 +274,20 @@ runTrigger userTrigger = LowLevel.Trigger
, heartbeat = userTrigger.heartbeat
}
where
initialState party (ActiveContracts createdEvents) =
initialState party readAs (ActiveContracts createdEvents) =
let acs = foldl (\acs created -> applyEvent (CreatedEvent created) acs) (ACS mempty Map.empty) createdEvents
userState = runTriggerInitializeA userTrigger.initialize acs
state = TriggerState acs party userState Map.empty
userState = runTriggerInitializeA userTrigger.initialize (acs, readAs)
state = TriggerState acs party readAs userState Map.empty
in TriggerSetup $ execStateT (runTriggerRule $ runRule userTrigger.rule) state
utUpdateState commandsInFlight acs msg = execState $ flip runTriggerUpdateA (commandsInFlight, acs) $ userTrigger.updateState msg
utUpdateState commandsInFlight acs readAs msg = execState $ flip runTriggerUpdateA (commandsInFlight, acs, readAs) $ userTrigger.updateState msg
update msg = do
time <- getTime
state <- get
case msg of
MCompletion completion ->
-- NB: the commands-in-flight and ACS updateState sees are those
-- prior to updates incurred by the msg
let userState = utUpdateState state.commandsInFlight state.acs (MCompletion completion) state.userState
let userState = utUpdateState state.commandsInFlight state.acs state.readAs (MCompletion completion) state.userState
in case completion.status of
Succeeded {} ->
-- We delete successful completions when we receive the corresponding transaction
Expand All @@ -293,14 +301,14 @@ runTrigger userTrigger = LowLevel.Trigger
MTransaction transaction -> do
let acs = applyTransaction transaction state.acs
-- again, we use the commands-in-flight and ACS before the update below
userState = utUpdateState state.commandsInFlight acs (MTransaction transaction) state.userState
userState = utUpdateState state.commandsInFlight acs state.readAs (MTransaction transaction) state.userState
-- See the comment above for why we delete this here instead of when we receive the completion.
(acs', commandsInFlight) = case transaction.commandId of
None -> (acs, state.commandsInFlight)
Some commandId -> (acs { pendingContracts = Map.delete commandId acs.pendingContracts }, Map.delete commandId state.commandsInFlight)
put $ state { acs = acs', userState, commandsInFlight }
runRule userTrigger.rule
MHeartbeat -> do
let userState = utUpdateState state.commandsInFlight state.acs MHeartbeat state.userState
let userState = utUpdateState state.commandsInFlight state.acs state.readAs MHeartbeat state.userState
put $ state { userState }
runRule userTrigger.rule
4 changes: 3 additions & 1 deletion triggers/daml/Daml/Trigger/Assert.daml
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,18 @@ toACS cid = ACSBuilder $ \p ->
testRule
: Trigger s -- ^ Test this trigger's 'Trigger.rule'.
-> Party -- ^ Execute the rule as this 'Party'.
-> [Party] -- ^ Execute the rule with these parties as `readAs`
-> ACSBuilder -- ^ List these contracts in the 'ACS'.
-> Map CommandId [Command] -- ^ The commands in flight.
-> s -- ^ The trigger state.
-> Script (s, [Commands]) -- ^ The 'Commands' and new state emitted by the rule. The 'CommandId's will start from @"0"@.
testRule trigger party acsBuilder commandsInFlight s = do
testRule trigger party readAs acsBuilder commandsInFlight s = do
time <- getTime
acs <- buildACS party acsBuilder
let state = TriggerState
{ acs = acs
, party = party
, readAs = readAs
, userState = s
, commandsInFlight = commandsInFlight
}
Expand Down
12 changes: 8 additions & 4 deletions triggers/daml/Daml/Trigger/Internal.daml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ instance HasTime (TriggerA s) where
-- cannot use `emitCommands` or `getTime`.
newtype TriggerUpdateA s a =
-- | HIDE
TriggerUpdateA { runTriggerUpdateA : (Map CommandId [Command], ACS) -> State s a }
TriggerUpdateA { runTriggerUpdateA : (Map CommandId [Command], ACS, [Party]) -> State s a }

instance Functor (TriggerUpdateA s) where
fmap f (TriggerUpdateA r) = TriggerUpdateA $ rliftFmap fmap f r
Expand All @@ -102,7 +102,7 @@ instance ActionState s (TriggerUpdateA s) where
-- trigger. It can query, but not emit commands or update the state.
newtype TriggerInitializeA a =
-- | HIDE
TriggerInitializeA { runTriggerInitializeA : ACS -> a }
TriggerInitializeA { runTriggerInitializeA : (ACS, [Party]) -> a }
deriving (Functor, Applicative, Action)

-- Internal API
Expand Down Expand Up @@ -171,12 +171,13 @@ runRule rule = do
state <- get
TriggerRule . zoom zoomIn zoomOut . runTriggerRule . flip runTriggerA state.acs
$ rule state.party
where zoomIn state = TriggerAState state.commandsInFlight state.acs.pendingContracts state.userState
where zoomIn state = TriggerAState state.commandsInFlight state.acs.pendingContracts state.userState state.readAs
zoomOut state aState =
let commandsInFlight = aState.commandsInFlight
acs = state.acs { pendingContracts = aState.pendingContracts }
userState = aState.userState
in state { commandsInFlight, acs, userState }
readAs = aState.readAs
in state { commandsInFlight, acs, userState, readAs }

-- | HIDE
liftTriggerRule : TriggerRule (TriggerAState s) a -> TriggerA s a
Expand All @@ -192,12 +193,15 @@ data TriggerAState s = TriggerAState
-- zoomed from TriggerState's acs.
, userState : s
-- ^ zoomed from TriggerState
, readAs : [Party]
-- ^ zoomed from TriggerState
}

-- | HIDE
data TriggerState s = TriggerState
{ acs : ACS
, party : Party
, readAs : [Party]
, userState : s
, commandsInFlight : Map CommandId [Command]
}
Expand Down
2 changes: 1 addition & 1 deletion triggers/daml/Daml/Trigger/LowLevel.daml
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ data ActiveContracts = ActiveContracts { activeContracts : [Created] }
-- | Trigger is (approximately) a left-fold over `Message` with
-- an accumulator of type `s`.
data Trigger s = Trigger
{ initialState : Party -> ActiveContracts -> TriggerSetup s
{ initialState : Party -> [Party] -> ActiveContracts -> TriggerSetup s
, update : Message -> TriggerRule s ()
, registeredTemplates : RegisteredTemplates
, heartbeat : Optional RelTime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ import com.daml.ledger.api.v1.transaction_filter.{Filters, InclusiveFilters, Tra
import com.daml.ledger.client.LedgerClient
import com.daml.ledger.client.services.commands.CompletionStreamElement._
import com.daml.lf.archive.Dar
import com.daml.lf.data.FrontStack
import com.daml.lf.data.ImmArray
import com.daml.lf.data.Ref
import com.daml.lf.data.Ref._
import com.daml.lf.data.ScalazEqual._
import com.daml.lf.data.Time.Timestamp
import com.daml.lf.language.Ast._
import com.daml.lf.language.PackageInterface
import com.daml.lf.language.Util._
import com.daml.lf.speedy.SExpr._
import com.daml.lf.speedy.SResult._
import com.daml.lf.speedy.SValue._
Expand Down Expand Up @@ -71,6 +74,8 @@ final case class Trigger(
// TransactionFilter since the latter is
// party-specific.
heartbeat: Option[FiniteDuration],
// Whether the trigger supports readAs claims (SDK 1.18 and newer) or not.
hasReadAs: Boolean,
) {
@nowarn("msg=parameter value label .* is never used") // Proxy only
private[trigger] final class withLoggingContext[P] private (
Expand Down Expand Up @@ -112,6 +117,26 @@ object Machine extends StrictLogging {
}

object Trigger extends StrictLogging {

private def detectHasReadAs(
interface: PackageInterface,
triggerIds: TriggerIds,
): Either[String, Boolean] =
for {
fieldInfo <- interface
.lookupRecordFieldInfo(
triggerIds.damlTriggerLowLevel("Trigger"),
Name.assertFromString("initialState"),
)
.left
.map(_.pretty)
hasReadAs <- fieldInfo.typDef match {
case TFun(TParty, TFun(TList(TParty), TFun(_, _))) => Right(true)
case TFun(TParty, TFun(_, _)) => Right(false)
case t => Left(s"Internal error: Unexpected type for initialState function: $t")
}
} yield hasReadAs

def fromIdentifier(
compiledPackages: CompiledPackages,
triggerId: Identifier,
Expand Down Expand Up @@ -154,10 +179,11 @@ object Trigger extends StrictLogging {
case _ => Left(s"Trigger must points to a value but points to $definition")
}
triggerIds = TriggerIds(expr.ty.tycon.packageId)
hasReadAs <- detectHasReadAs(compiledPackages.interface, triggerIds)
converter: Converter = Converter(compiledPackages, triggerIds)
filter <- getTriggerFilter(compiledPackages, compiler, converter, expr)
heartbeat <- getTriggerHeartbeat(compiledPackages, compiler, converter, expr)
} yield Trigger(expr, triggerId, triggerIds, filter, heartbeat)
} yield Trigger(expr, triggerId, triggerIds, filter, heartbeat, hasReadAs)
}

// Return the heartbeat specified by the user.
Expand Down Expand Up @@ -220,7 +246,7 @@ class Runner(
client: LedgerClient,
timeProviderType: TimeProviderType,
applicationId: ApplicationId,
party: Party,
parties: TriggerParties,
)(implicit loggingContext: LoggingContextOf[Trigger]) {
import Runner.{SeenMsgs, alterF}

Expand All @@ -233,7 +259,7 @@ class Runner(
// message, or both.
private[this] var pendingCommandIds: Map[UUID, SeenMsgs] = Map.empty
private val transactionFilter: TransactionFilter =
TransactionFilter(Seq((party.unwrap, trigger.filters)).toMap)
TransactionFilter(parties.readers.map(p => (p.unwrap, trigger.filters)).toMap)

private[this] def logger = ContextualizedLogger get getClass

Expand Down Expand Up @@ -261,7 +287,8 @@ class Runner(
ledgerId = client.ledgerId.unwrap,
applicationId = applicationId.unwrap,
commandId = commandUUID.toString,
party = party.unwrap,
party = parties.actAs.unwrap,
readAs = Party.unsubst(parties.readAs).toList,
commands = commands,
)
logger.debug(
Expand Down Expand Up @@ -322,7 +349,7 @@ class Runner(
client: LedgerClient,
offset: LedgerOffset,
heartbeat: Option[FiniteDuration],
party: Party,
parties: TriggerParties,
filter: TransactionFilter,
): Flow[SingleCommandFailure, TriggerMsg, NotUsed] = {

Expand Down Expand Up @@ -362,7 +389,8 @@ class Runner(
submissionFailureQueue
.merge(
client.commandClient
.completionSource(List(party.unwrap), offset)
// Completions only take actAs into account so no need to include readAs.
.completionSource(List(parties.actAs.unwrap), offset)
.mapConcat {
case CheckpointElement(_) => List()
case CompletionElement(c) => List(c)
Expand Down Expand Up @@ -403,10 +431,17 @@ class Runner(
// Convert the ACS to a speedy value.
val createdValue: SValue = converter.fromACS(acs).orConverterException
// Setup an application expression of initialState on the ACS.
val partyArg = SParty(Ref.Party.assertFromString(parties.actAs.unwrap))
val initialStateArgs = if (trigger.hasReadAs) {
val readAsArg = SList(
parties.readAs.map(p => SParty(Ref.Party.assertFromString(p.unwrap))).to(FrontStack)
)
Array(partyArg, readAsArg, createdValue)
} else Array(partyArg, createdValue)
val initialState: SExpr =
makeApp(
getInitialState,
Array(SParty(Ref.Party.assertFromString(party.unwrap)), createdValue),
initialStateArgs,
)
// Prepare a speedy machine for evaluating expressions.
val machine: Speedy.Machine =
Expand Down Expand Up @@ -448,7 +483,8 @@ class Runner(
name: String,
acs: Seq[CreatedEvent],
): Flow[TriggerMsg, SubmitRequest, Future[SValue]] = {
logger.info(s"Trigger ${name} is running as ${party}")
logger.info(s"""Trigger ${name} is running as ${parties.actAs} with readAs=[${parties.readAs
.mkString(", ")}]""")

val clientTime: Timestamp =
Timestamp.assertFromInstant(Runner.getTimeProvider(timeProviderType).getCurrentTime)
Expand Down Expand Up @@ -617,7 +653,7 @@ class Runner(
executionContext: ExecutionContext,
): (T, Future[SValue]) = {
val source =
msgSource(client, offset, trigger.heartbeat, party, transactionFilter)
msgSource(client, offset, trigger.heartbeat, parties, transactionFilter)
Flow
.fromGraph(msgFlow)
.viaMat(getTriggerEvaluator(name, acs))(Keep.both)
Expand Down Expand Up @@ -720,7 +756,7 @@ object Runner extends StrictLogging {
client: LedgerClient,
timeProviderType: TimeProviderType,
applicationId: ApplicationId,
party: Party,
parties: TriggerParties,
config: Compiler.Config,
)(implicit materializer: Materializer, executionContext: ExecutionContext): Future[SValue] = {
val darMap = dar.all.toMap
Expand All @@ -734,7 +770,7 @@ object Runner extends StrictLogging {
}
val runner =
trigger.withLoggingContext { implicit lc =>
new Runner(compiledPackages, trigger, client, timeProviderType, applicationId, party)
new Runner(compiledPackages, trigger, client, timeProviderType, applicationId, parties)
}
for {
(acs, offset) <- runner.queryACS()
Expand Down
Loading

0 comments on commit 07ad3e0

Please sign in to comment.