Skip to content

Commit

Permalink
refactor: dispatch by worker to allow atomic cursors (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
scarmuega authored Nov 5, 2024
1 parent 09672d5 commit 50389ad
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 132 deletions.
186 changes: 98 additions & 88 deletions balius-runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use pallas::ledger::traverse::MultiEraBlock;
use router::Router;
use serde_json::json;
use std::{
collections::{HashMap, HashSet},
Expand Down Expand Up @@ -89,6 +90,12 @@ impl From<redb::TableError> for Error {
}
}

impl From<redb::CommitError> for Error {
fn from(value: redb::CommitError) -> Self {
Self::Store(value.into())
}
}

impl From<redb::StorageError> for Error {
fn from(value: redb::StorageError) -> Self {
Self::Store(value.into())
Expand Down Expand Up @@ -124,15 +131,66 @@ impl wit::balius::app::driver::Host for WorkerState {
id: u32,
pattern: wit::balius::app::driver::EventPattern,
) -> () {
self.router.register_channel(&self.worker_id, id, &pattern);
self.router.register_channel(id, &pattern);
}
}

struct LoadedWorker {
store: wasmtime::Store<WorkerState>,
wasm_store: wasmtime::Store<WorkerState>,
instance: wit::Worker,
}

impl LoadedWorker {
pub async fn dispatch_event(
&mut self,
channel: u32,
event: &wit::Event,
) -> Result<wit::Response, Error> {
self.instance
.call_handle(&mut self.wasm_store, channel, event)
.await?
.map_err(|err| Error::Handle(err.code, err.message))
}

async fn acknowledge_event(&mut self, channel: u32, event: &wit::Event) -> Result<(), Error> {
let result = self.dispatch_event(channel, event).await;

match result {
Ok(wit::Response::Acknowledge) => {
tracing::debug!("worker acknowledge");
}
Ok(_) => {
tracing::warn!("worker returned unexpected data");
}
Err(Error::Handle(code, message)) => {
tracing::warn!(code, message);
}
Err(e) => return Err(e),
}

Ok(())
}

async fn apply_block(
&mut self,
block: &MultiEraBlock<'_>,
log_seq: LogSeq,
) -> Result<(), Error> {
for tx in block.txs() {
for (_, utxo) in tx.produces() {
let event = wit::Event::Utxo(utxo.encode());
let channels = self.wasm_store.data().router.find_utxo_targets(&utxo)?;

for channel in channels {
self.acknowledge_event(channel, &event).await?;
}
}
}

Ok(())
}
}

type WorkerMap = HashMap<String, LoadedWorker>;

#[derive(Clone)]
Expand All @@ -141,7 +199,6 @@ pub struct Runtime {
linker: wasmtime::component::Linker<WorkerState>,
loaded: Arc<Mutex<WorkerMap>>,

router: router::Router,
store: store::Store,
ledger: Option<ledgers::Ledger>,
kv: Option<kv::Kv>,
Expand All @@ -153,10 +210,9 @@ impl Runtime {
RuntimeBuilder::new(store)
}

pub fn cursor(&self) -> Result<Option<LogSeq>, Error> {
let cursor = self.store.lowest_cursor()?;

Ok(cursor)
pub fn chain_cursor(&self) -> Result<Option<ChainPoint>, Error> {
// TODO: iterate over all workers and find the lowest cursor
todo!()
}

pub async fn register_worker(
Expand All @@ -167,122 +223,77 @@ impl Runtime {
) -> Result<(), Error> {
let component = wasmtime::component::Component::from_file(&self.engine, wasm_path)?;

let mut store = wasmtime::Store::new(
let mut wasm_store = wasmtime::Store::new(
&self.engine,
WorkerState {
worker_id: id.to_owned(),
router: self.router.clone(),
router: Router::new(),
ledger: self.ledger.clone(),
kv: self.kv.clone(),
submit: self.submit.clone(),
},
);

let instance = wit::Worker::instantiate_async(&mut store, &component, &self.linker).await?;
let instance =
wit::Worker::instantiate_async(&mut wasm_store, &component, &self.linker).await?;

let config = serde_json::to_vec(&config).unwrap();
instance.call_init(&mut store, &config).await?;
instance.call_init(&mut wasm_store, &config).await?;

self.loaded
.lock()
.await
.insert(id.to_owned(), LoadedWorker { store, instance });
self.loaded.lock().await.insert(
id.to_owned(),
LoadedWorker {
wasm_store,
instance,
},
);

Ok(())
}

pub async fn dispatch_event(
&self,
worker: &str,
channel: u32,
event: &wit::Event,
) -> Result<wit::Response, Error> {
let mut lock = self.loaded.lock().await;
pub async fn apply_block(&self, block: &MultiEraBlock<'_>) -> Result<(), Error> {
let log_seq = self.store.write_ahead(block)?;

let worker = lock
.get_mut(worker)
.ok_or(Error::WorkerNotFound(worker.to_string()))?;

let result = worker
.instance
.call_handle(&mut worker.store, channel, event)
.await?;
let mut lock = self.loaded.lock().await;

let response = result.map_err(|err| Error::Handle(err.code, err.message))?;
let mut atomic_update = self.store.start_atomic_update()?;

Ok(response)
}

async fn fire_and_forget(
&self,
event: &wit::Event,
targets: HashSet<router::Target>,
) -> Result<(), Error> {
for target in targets {
let result = self
.dispatch_event(&target.worker, target.channel, event)
.await;

match result {
Ok(wit::Response::Acknowledge) => {
tracing::debug!(worker = target.worker, "worker acknowledge");
}
Ok(_) => {
tracing::warn!(worker = target.worker, "worker returned unexpected data");
}
Err(Error::Handle(code, message)) => {
tracing::warn!(code, message);
}
Err(e) => return Err(e),
}
for (_, worker) in lock.iter_mut() {
worker.apply_block(block, log_seq).await?;
atomic_update.set_worker_cursor(&worker.wasm_store.data().worker_id, log_seq)?;
}

Ok(())
}

pub async fn apply_block(
&self,
block: &MultiEraBlock<'_>,
wal_seq: LogSeq,
) -> Result<(), Error> {
for tx in block.txs() {
for utxo in tx.outputs() {
let targets = self.router.find_utxo_targets(&utxo)?;
let event = wit::Event::Utxo(utxo.encode());

self.fire_and_forget(&event, targets).await?;
}
}
atomic_update.commit()?;

Ok(())
}

pub fn undo_block(&self, block: &MultiEraBlock, wal_seq: LogSeq) -> Result<(), Error> {
// TODO: implement undo once we have "apply" working
pub async fn undo_block(&self, block: &MultiEraBlock<'_>) -> Result<(), Error> {
Ok(())
}

pub async fn handle_request(
&self,
worker: &str,
method: &str,
params: serde_json::Value,
) -> Result<serde_json::Value, Error> {
let target = self.router.find_request_target(worker, method)?;
params: Vec<u8>,
) -> Result<wit::Response, Error> {
let mut lock = self.loaded.lock().await;

let evt = wit::Event::Request(serde_json::to_vec(&params).unwrap());
let worker = lock
.get_mut(worker)
.ok_or(Error::WorkerNotFound(worker.to_string()))?;

let reply = self
.dispatch_event(&target.worker, target.channel, &evt)
.await?;
let channel = worker
.wasm_store
.data()
.router
.find_request_target(method)?;

let json = match reply {
wit::Response::Acknowledge => json!({}),
wit::Response::Json(x) => serde_json::from_slice(&x).unwrap(),
wit::Response::Cbor(x) => json!({ "cbor": x }),
wit::Response::PartialTx(x) => json!({ "tx": x }),
};
let evt = wit::Event::Request(params);

Ok(json)
worker.dispatch_event(channel, &evt).await
}
}

Expand Down Expand Up @@ -360,7 +371,6 @@ impl RuntimeBuilder {

Ok(Runtime {
loaded: Default::default(),
router: router::Router::new(),
engine,
linker,
store,
Expand Down
58 changes: 22 additions & 36 deletions balius-runtime/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use pallas::ledger::traverse::MultiEraOutput;

use crate::wit::balius::app::driver::EventPattern;
use crate::wit::balius::app::driver::{Event, EventPattern};

type WorkerId = String;
type ChannelId = u32;
Expand All @@ -14,62 +14,49 @@ type AddressBytes = Vec<u8>;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum MatchKey {
RequestMethod(WorkerId, Method),
RequestMethod(Method),
UtxoAddress(AddressBytes),
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Target {
pub channel: ChannelId,
pub worker: String,
}

fn infer_match_keys(worker: &str, pattern: &EventPattern) -> Vec<MatchKey> {
fn infer_match_keys(pattern: &EventPattern) -> Vec<MatchKey> {
match pattern {
EventPattern::Request(x) => vec![MatchKey::RequestMethod(worker.to_owned(), x.to_owned())],
EventPattern::Request(x) => vec![MatchKey::RequestMethod(x.to_owned())],
EventPattern::Utxo(_) => todo!(),
EventPattern::UtxoUndo(_) => todo!(),
EventPattern::Timer(_) => todo!(),
EventPattern::Message(_) => todo!(),
}
}

type RouteMap = HashMap<MatchKey, HashSet<Target>>;
type RouteMap = HashMap<MatchKey, HashSet<ChannelId>>;

#[derive(Default, Clone)]
pub struct Router {
routes: Arc<RwLock<RouteMap>>,
routes: RouteMap,
}

impl Router {
pub fn new() -> Self {
Self {
routes: Arc::new(RwLock::new(Default::default())),
}
Default::default()
}

pub fn register_channel(&mut self, worker: &str, channel: u32, pattern: &EventPattern) {
let keys = infer_match_keys(worker, pattern);
let mut routes = self.routes.write().unwrap();
pub fn register_channel(&mut self, channel: u32, pattern: &EventPattern) {
let keys = infer_match_keys(pattern);

for key in keys {
let targets = routes.entry(key).or_default();
let targets = self.routes.entry(key).or_default();

targets.insert(Target {
worker: worker.to_string(),
channel,
});
targets.insert(channel);
}
}

pub fn find_utxo_targets(
&self,
utxo: &MultiEraOutput,
) -> Result<HashSet<Target>, super::Error> {
let routes = self.routes.read().unwrap();

) -> Result<HashSet<ChannelId>, super::Error> {
let key = MatchKey::UtxoAddress(utxo.address()?.to_vec());
let targets: HashSet<_> = routes
let targets: HashSet<_> = self
.routes
.get(&key)
.iter()
.flat_map(|x| x.iter())
Expand All @@ -81,11 +68,10 @@ impl Router {
Ok(targets)
}

pub fn find_request_target(&self, worker: &str, method: &str) -> Result<Target, super::Error> {
let key = MatchKey::RequestMethod(worker.to_owned(), method.to_owned());
let routes = self.routes.read().unwrap();
pub fn find_request_target(&self, method: &str) -> Result<ChannelId, super::Error> {
let key = MatchKey::RequestMethod(method.to_owned());

let targets = routes.get(&key).ok_or(super::Error::NoTarget)?;
let targets = self.routes.get(&key).ok_or(super::Error::NoTarget)?;

if targets.is_empty() {
return Err(super::Error::NoTarget);
Expand All @@ -108,14 +94,14 @@ mod tests {
#[test]
fn test_request_channel() {
let mut router = Router::new();
let worker = "test_worker";

let method = "test_method";
let channel = 1;

router.register_channel(worker, channel, &EventPattern::Request(method.to_string()));
router.register_channel(channel, &EventPattern::Request(method.to_string()));

let channel = router.find_request_target(method).unwrap();

let target = router.find_request_target(worker, method).unwrap();
assert_eq!(target.worker, worker);
assert_eq!(target.channel, channel);
assert_eq!(channel, channel);
}
}
Loading

0 comments on commit 50389ad

Please sign in to comment.