diff --git a/src/swapd/runtime.rs b/src/swapd/runtime.rs index ec29791ff..dcfc7fc8b 100644 --- a/src/swapd/runtime.rs +++ b/src/swapd/runtime.rs @@ -210,16 +210,71 @@ pub struct Runtime { } // FIXME Something more meaningful than ServiceId to index -#[derive(Default)] -pub struct PendingRequests(HashMap>); +type PendingRequests = HashMap>; -impl From>> for PendingRequests { - fn from(m: HashMap>) -> Self { - PendingRequests(m) +impl PendingRequestsT for PendingRequests { + fn defer_request(&mut self, key: ServiceId, pending_req: PendingRequest) { + let pending_reqs = self.entry(key).or_insert(vec![]); + pending_reqs.push(pending_req); + } + fn continue_deferred_requests( + runtime: &mut Runtime, + endpoints: &mut Endpoints, + key: ServiceId, + predicate: fn(&PendingRequest) -> bool, + ) -> bool { + let success = if let Some(pending_reqs) = runtime.pending_requests.remove(&key) { + let len0 = pending_reqs.len(); + let remaining_pending_reqs: Vec<_> = pending_reqs + .into_iter() + .filter_map(|r| { + if predicate(&r) { + if let Ok(_) = match &r.bus_id { + ServiceBus::Ctl if &r.dest == &runtime.identity => runtime + .handle_rpc_ctl(endpoints, r.source.clone(), r.request.clone()), + ServiceBus::Msg if &r.dest == &runtime.identity => runtime + .handle_rpc_msg(endpoints, r.source.clone(), r.request.clone()), + _ => endpoints + .send_to( + r.bus_id.clone(), + r.source.clone(), + r.dest.clone(), + r.request.clone(), + ) + .map_err(Into::into), + } { + None + } else { + Some(r) + } + } else { + Some(r) + } + }) + .collect(); + let len1 = remaining_pending_reqs.len(); + runtime.pending_requests.insert(key, remaining_pending_reqs); + if len0 - len1 > 1 { + error!("consumed more than one request with this predicate") + } + len0 > len1 + } else { + error!("no request consumed with this predicate"); + false + }; + success } } -impl PendingRequests {} +trait PendingRequestsT { + fn defer_request(&mut self, key: ServiceId, pending_req: PendingRequest); + fn continue_deferred_requests( + runtime: &mut Runtime, // needed for recursion + endpoints: &mut Endpoints, + key: ServiceId, + predicate: fn(&PendingRequest) -> bool, + ) -> bool; +} #[derive(Debug, Clone)] pub struct PendingRequest { @@ -238,10 +293,6 @@ impl PendingRequest { request, } } - fn defer_request(self, pending_requests: &mut PendingRequests, key: ServiceId) { - let pending_reqs = pending_requests.0.entry(key).or_insert(vec![]); - pending_reqs.push(self); - } } impl StrictEncode for PendingRequest { @@ -467,56 +518,7 @@ impl Runtime { } fn pending_requests(&mut self) -> &mut HashMap> { - &mut self.pending_requests.0 - } - - fn continue_deferred_requests( - &mut self, - endpoints: &mut Endpoints, - key: ServiceId, - predicate: fn(&PendingRequest) -> bool, - ) -> bool { - let success = if let Some(pending_reqs) = self.pending_requests().remove(&key) { - let len0 = pending_reqs.len(); - let remaining_pending_reqs: Vec<_> = - pending_reqs - .into_iter() - .filter_map(|r| { - if predicate(&r) { - if let Ok(_) = match &r.bus_id { - ServiceBus::Ctl if &r.dest == &self.identity => self - .handle_rpc_ctl(endpoints, r.source.clone(), r.request.clone()), - ServiceBus::Msg if &r.dest == &self.identity => self - .handle_rpc_msg(endpoints, r.source.clone(), r.request.clone()), - _ => endpoints - .send_to( - r.bus_id.clone(), - r.source.clone(), - r.dest.clone(), - r.request.clone(), - ) - .map_err(Into::into), - } { - None - } else { - Some(r) - } - } else { - Some(r) - } - }) - .collect(); - let len1 = remaining_pending_reqs.len(); - self.pending_requests().insert(key, remaining_pending_reqs); - if len0 - len1 > 1 { - error!("consumed more than one request with this predicate") - } - len0 > len1 - } else { - error!("no request consumed with this predicate"); - false - }; - success + &mut self.pending_requests } fn state_update(&mut self, endpoints: &mut Endpoints, next_state: State) -> Result<(), Error> { @@ -631,8 +633,9 @@ impl Runtime { ServiceBus::Msg, request, ); - pending_request - .defer_request(&mut self.pending_requests, ServiceId::Wallet); + + self.pending_requests + .defer_request(ServiceId::Wallet, pending_request); } SwapRole::Alice => { debug!("Alice: forwarding reveal"); @@ -661,10 +664,8 @@ impl Runtime { &request ); let pending_req = PendingRequest::new(source, self.identity(), msg_bus, request); - pending_req.defer_request( - &mut self.pending_requests, - self.syncer_state.bitcoin_syncer(), - ); + self.pending_requests + .defer_request(self.syncer_state.bitcoin_syncer(), pending_req); } // bob and alice // store parameters from counterparty if we have not received them yet. @@ -729,7 +730,8 @@ impl Runtime { msg_bus, request, ); - pending_req.defer_request(&mut self.pending_requests, ServiceId::Wallet); + self.pending_requests + .defer_request(ServiceId::Wallet, pending_req); } _ => unreachable!( "Bob btc_fee_estimate_sat_per_kvb.is_none() was handled previously" @@ -936,7 +938,7 @@ impl Runtime { let dest = self.syncer_state.monero_syncer(); let pending_request = PendingRequest::new(self.identity(), dest.clone(), ServiceBus::Ctl, request); - pending_request.defer_request(&mut self.pending_requests, dest); + self.pending_requests.defer_request(dest, pending_request); } Request::TakeSwap(InitSwap { peerd, @@ -1055,8 +1057,11 @@ impl Runtime { .map(|reqs| reqs.len() == 2) .unwrap() => { - let success_proof = - self.continue_deferred_requests(endpoints, source.clone(), |r| { + let success_proof = PendingRequests::continue_deferred_requests( + self, + endpoints, + source.clone(), + |r| { matches!( r, &PendingRequest { @@ -1066,22 +1071,24 @@ impl Runtime { .. } ) - }); + }, + ); if !success_proof { error!("Did not dispatch proof pending request"); } - let success_params = self.continue_deferred_requests(endpoints, source, |r| { - matches!( - r, - &PendingRequest { - dest: ServiceId::Wallet, - bus_id: ServiceBus::Msg, - request: Request::Protocol(Msg::Reveal(Reveal::AliceParameters(_))), - .. - } - ) - }); + let success_params = + PendingRequests::continue_deferred_requests(self, endpoints, source, |r| { + matches!( + r, + &PendingRequest { + dest: ServiceId::Wallet, + bus_id: ServiceBus::Msg, + request: Request::Protocol(Msg::Reveal(Reveal::AliceParameters(_))), + .. + } + ) + }); if !success_params { error!("Did not dispatch params pending requests"); } @@ -1270,16 +1277,21 @@ impl Runtime { .unwrap() => { // error!("not checking tx rcvd is accordant lock"); - let success = self.continue_deferred_requests(endpoints, source, |r| { - matches!( - r, - &PendingRequest { - bus_id: ServiceBus::Msg, - request: Request::Protocol(Msg::BuyProcedureSignature(_)), - .. - } - ) - }); + let success = PendingRequests::continue_deferred_requests( + self, + endpoints, + source, + |r| { + matches!( + r, + &PendingRequest { + bus_id: ServiceBus::Msg, + request: Request::Protocol(Msg::BuyProcedureSignature(_)), + .. + } + ) + }, + ); if success { let next_state = State::Bob(BobState::BuySigB { buy_tx_seen: false, @@ -1990,19 +2002,24 @@ impl Runtime { && self.state.b_address().is_some() && self.syncer_state.btc_fee_estimate_sat_per_kvb.is_some() { - let success = self.continue_deferred_requests(endpoints, source, |i| { - matches!( - &i, - &PendingRequest { - bus_id: ServiceBus::Msg, - request: Request::Protocol(Msg::Reveal( - Reveal::AliceParameters(..) - )), - dest: ServiceId::Swap(..), - .. - } - ) - }); + let success = PendingRequests::continue_deferred_requests( + self, + endpoints, + source, + |i| { + matches!( + &i, + &PendingRequest { + bus_id: ServiceBus::Msg, + request: Request::Protocol(Msg::Reveal( + Reveal::AliceParameters(..) + )), + dest: ServiceId::Swap(..), + .. + } + ) + }, + ); if success { debug!("successfully dispatched reveal:aliceparams") } else { @@ -2258,10 +2275,8 @@ impl Runtime { ServiceBus::Msg, request, ); - pending_request.defer_request( - &mut self.pending_requests, - self.syncer_state.monero_syncer(), - ); + self.pending_requests + .defer_request(self.syncer_state.monero_syncer(), pending_request); } Request::AbortSwap @@ -2377,7 +2392,7 @@ impl Runtime { self.state = state; self.enquirer = enquirer; self.temporal_safety = temporal_safety; - self.pending_requests = PendingRequests(pending_requests); + self.pending_requests = pending_requests; self.txs = txs.clone(); trace!("Watch height bitcoin"); let watch_height_bitcoin = self.syncer_state.watch_height(Blockchain::Bitcoin);