Skip to content

Commit

Permalink
Merge pull request #637 from zkao/defer_reqs
Browse files Browse the repository at this point in the history
swapd: improve PendingRequests and add associated fn defer_request
  • Loading branch information
TheCharlatan authored Aug 11, 2022
2 parents 1ef68fa + 72e7ef8 commit f11f14c
Showing 1 changed file with 125 additions and 110 deletions.
235 changes: 125 additions & 110 deletions src/swapd/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,16 +210,71 @@ pub struct Runtime {
}

// FIXME Something more meaningful than ServiceId to index
#[derive(Default)]
pub struct PendingRequests(HashMap<ServiceId, Vec<PendingRequest>>);
type PendingRequests = HashMap<ServiceId, Vec<PendingRequest>>;

impl From<HashMap<ServiceId, Vec<PendingRequest>>> for PendingRequests {
fn from(m: HashMap<ServiceId, Vec<PendingRequest>>) -> Self {
PendingRequests(m)
impl PendingRequestsT for PendingRequests {
fn defer_request(&mut self, key: ServiceId, pending_req: PendingRequest) {
let pending_reqs = self.entry(key).or_insert(vec![]);
pending_reqs.push(pending_req);
}
fn continue_deferred_requests(
runtime: &mut Runtime,
endpoints: &mut Endpoints,
key: ServiceId,
predicate: fn(&PendingRequest) -> bool,
) -> bool {
let success = if let Some(pending_reqs) = runtime.pending_requests.remove(&key) {
let len0 = pending_reqs.len();
let remaining_pending_reqs: Vec<_> = pending_reqs
.into_iter()
.filter_map(|r| {
if predicate(&r) {
if let Ok(_) = match &r.bus_id {
ServiceBus::Ctl if &r.dest == &runtime.identity => runtime
.handle_rpc_ctl(endpoints, r.source.clone(), r.request.clone()),
ServiceBus::Msg if &r.dest == &runtime.identity => runtime
.handle_rpc_msg(endpoints, r.source.clone(), r.request.clone()),
_ => endpoints
.send_to(
r.bus_id.clone(),
r.source.clone(),
r.dest.clone(),
r.request.clone(),
)
.map_err(Into::into),
} {
None
} else {
Some(r)
}
} else {
Some(r)
}
})
.collect();
let len1 = remaining_pending_reqs.len();
runtime.pending_requests.insert(key, remaining_pending_reqs);
if len0 - len1 > 1 {
error!("consumed more than one request with this predicate")
}
len0 > len1
} else {
error!("no request consumed with this predicate");
false
};
success
}
}

impl PendingRequests {}
trait PendingRequestsT {
fn defer_request(&mut self, key: ServiceId, pending_req: PendingRequest);
fn continue_deferred_requests(
runtime: &mut Runtime, // needed for recursion
endpoints: &mut Endpoints,
key: ServiceId,
predicate: fn(&PendingRequest) -> bool,
) -> bool;
}

#[derive(Debug, Clone)]
pub struct PendingRequest {
Expand All @@ -238,10 +293,6 @@ impl PendingRequest {
request,
}
}
fn defer_request(self, pending_requests: &mut PendingRequests, key: ServiceId) {
let pending_reqs = pending_requests.0.entry(key).or_insert(vec![]);
pending_reqs.push(self);
}
}

impl StrictEncode for PendingRequest {
Expand Down Expand Up @@ -467,56 +518,7 @@ impl Runtime {
}

fn pending_requests(&mut self) -> &mut HashMap<ServiceId, Vec<PendingRequest>> {
&mut self.pending_requests.0
}

fn continue_deferred_requests(
&mut self,
endpoints: &mut Endpoints,
key: ServiceId,
predicate: fn(&PendingRequest) -> bool,
) -> bool {
let success = if let Some(pending_reqs) = self.pending_requests().remove(&key) {
let len0 = pending_reqs.len();
let remaining_pending_reqs: Vec<_> =
pending_reqs
.into_iter()
.filter_map(|r| {
if predicate(&r) {
if let Ok(_) = match &r.bus_id {
ServiceBus::Ctl if &r.dest == &self.identity => self
.handle_rpc_ctl(endpoints, r.source.clone(), r.request.clone()),
ServiceBus::Msg if &r.dest == &self.identity => self
.handle_rpc_msg(endpoints, r.source.clone(), r.request.clone()),
_ => endpoints
.send_to(
r.bus_id.clone(),
r.source.clone(),
r.dest.clone(),
r.request.clone(),
)
.map_err(Into::into),
} {
None
} else {
Some(r)
}
} else {
Some(r)
}
})
.collect();
let len1 = remaining_pending_reqs.len();
self.pending_requests().insert(key, remaining_pending_reqs);
if len0 - len1 > 1 {
error!("consumed more than one request with this predicate")
}
len0 > len1
} else {
error!("no request consumed with this predicate");
false
};
success
&mut self.pending_requests
}

fn state_update(&mut self, endpoints: &mut Endpoints, next_state: State) -> Result<(), Error> {
Expand Down Expand Up @@ -631,8 +633,9 @@ impl Runtime {
ServiceBus::Msg,
request,
);
pending_request
.defer_request(&mut self.pending_requests, ServiceId::Wallet);

self.pending_requests
.defer_request(ServiceId::Wallet, pending_request);
}
SwapRole::Alice => {
debug!("Alice: forwarding reveal");
Expand Down Expand Up @@ -661,10 +664,8 @@ impl Runtime {
&request
);
let pending_req = PendingRequest::new(source, self.identity(), msg_bus, request);
pending_req.defer_request(
&mut self.pending_requests,
self.syncer_state.bitcoin_syncer(),
);
self.pending_requests
.defer_request(self.syncer_state.bitcoin_syncer(), pending_req);
}
// bob and alice
// store parameters from counterparty if we have not received them yet.
Expand Down Expand Up @@ -729,7 +730,8 @@ impl Runtime {
msg_bus,
request,
);
pending_req.defer_request(&mut self.pending_requests, ServiceId::Wallet);
self.pending_requests
.defer_request(ServiceId::Wallet, pending_req);
}
_ => unreachable!(
"Bob btc_fee_estimate_sat_per_kvb.is_none() was handled previously"
Expand Down Expand Up @@ -936,7 +938,7 @@ impl Runtime {
let dest = self.syncer_state.monero_syncer();
let pending_request =
PendingRequest::new(self.identity(), dest.clone(), ServiceBus::Ctl, request);
pending_request.defer_request(&mut self.pending_requests, dest);
self.pending_requests.defer_request(dest, pending_request);
}
Request::TakeSwap(InitSwap {
peerd,
Expand Down Expand Up @@ -1055,8 +1057,11 @@ impl Runtime {
.map(|reqs| reqs.len() == 2)
.unwrap() =>
{
let success_proof =
self.continue_deferred_requests(endpoints, source.clone(), |r| {
let success_proof = PendingRequests::continue_deferred_requests(
self,
endpoints,
source.clone(),
|r| {
matches!(
r,
&PendingRequest {
Expand All @@ -1066,22 +1071,24 @@ impl Runtime {
..
}
)
});
},
);
if !success_proof {
error!("Did not dispatch proof pending request");
}

let success_params = self.continue_deferred_requests(endpoints, source, |r| {
matches!(
r,
&PendingRequest {
dest: ServiceId::Wallet,
bus_id: ServiceBus::Msg,
request: Request::Protocol(Msg::Reveal(Reveal::AliceParameters(_))),
..
}
)
});
let success_params =
PendingRequests::continue_deferred_requests(self, endpoints, source, |r| {
matches!(
r,
&PendingRequest {
dest: ServiceId::Wallet,
bus_id: ServiceBus::Msg,
request: Request::Protocol(Msg::Reveal(Reveal::AliceParameters(_))),
..
}
)
});
if !success_params {
error!("Did not dispatch params pending requests");
}
Expand Down Expand Up @@ -1270,16 +1277,21 @@ impl Runtime {
.unwrap() =>
{
// error!("not checking tx rcvd is accordant lock");
let success = self.continue_deferred_requests(endpoints, source, |r| {
matches!(
r,
&PendingRequest {
bus_id: ServiceBus::Msg,
request: Request::Protocol(Msg::BuyProcedureSignature(_)),
..
}
)
});
let success = PendingRequests::continue_deferred_requests(
self,
endpoints,
source,
|r| {
matches!(
r,
&PendingRequest {
bus_id: ServiceBus::Msg,
request: Request::Protocol(Msg::BuyProcedureSignature(_)),
..
}
)
},
);
if success {
let next_state = State::Bob(BobState::BuySigB {
buy_tx_seen: false,
Expand Down Expand Up @@ -1990,19 +2002,24 @@ impl Runtime {
&& self.state.b_address().is_some()
&& self.syncer_state.btc_fee_estimate_sat_per_kvb.is_some()
{
let success = self.continue_deferred_requests(endpoints, source, |i| {
matches!(
&i,
&PendingRequest {
bus_id: ServiceBus::Msg,
request: Request::Protocol(Msg::Reveal(
Reveal::AliceParameters(..)
)),
dest: ServiceId::Swap(..),
..
}
)
});
let success = PendingRequests::continue_deferred_requests(
self,
endpoints,
source,
|i| {
matches!(
&i,
&PendingRequest {
bus_id: ServiceBus::Msg,
request: Request::Protocol(Msg::Reveal(
Reveal::AliceParameters(..)
)),
dest: ServiceId::Swap(..),
..
}
)
},
);
if success {
debug!("successfully dispatched reveal:aliceparams")
} else {
Expand Down Expand Up @@ -2258,10 +2275,8 @@ impl Runtime {
ServiceBus::Msg,
request,
);
pending_request.defer_request(
&mut self.pending_requests,
self.syncer_state.monero_syncer(),
);
self.pending_requests
.defer_request(self.syncer_state.monero_syncer(), pending_request);
}

Request::AbortSwap
Expand Down Expand Up @@ -2377,7 +2392,7 @@ impl Runtime {
self.state = state;
self.enquirer = enquirer;
self.temporal_safety = temporal_safety;
self.pending_requests = PendingRequests(pending_requests);
self.pending_requests = pending_requests;
self.txs = txs.clone();
trace!("Watch height bitcoin");
let watch_height_bitcoin = self.syncer_state.watch_height(Blockchain::Bitcoin);
Expand Down

0 comments on commit f11f14c

Please sign in to comment.