From 7ca23d65fba5b115c98f500c0a163a293381e7fd Mon Sep 17 00:00:00 2001 From: Philip Robinson Date: Fri, 5 Jun 2020 12:15:18 +0200 Subject: [PATCH] Update Output Manager UTXO query to split into multiple queries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit

Previously, when the Output Manager service sent its UTXO validity query to its Base Node, it sent a single query containing all the UTXO hashes. The Base Node then responds with a single message containing every valid UTXO that matches a hash in the query. In large wallets this caused a problem: the response message could become large enough to be rejected by the frame size limits in the comms stack. To mitigate this, the Output Manager service will now send queries containing at most a configured maximum number of outputs (max_utxo_query_size) and will keep sending queries until all the outputs have been queried, so that no single response hits the frame size limit. (A minimal sketch of this batching logic follows the patch.)

--- .../src/output_manager_service/config.rs | 2 + .../src/output_manager_service/service.rs | 174 +++++++++++++----- .../tests/output_manager_service/service.rs | 160 ++++++++++++---- comms/dht/src/outbound/mock.rs | 8 + 4 files changed, 260 insertions(+), 84 deletions(-) diff --git a/base_layer/wallet/src/output_manager_service/config.rs b/base_layer/wallet/src/output_manager_service/config.rs index 9b7ba8c89c..9c8eecc2ba 100644 --- a/base_layer/wallet/src/output_manager_service/config.rs +++ b/base_layer/wallet/src/output_manager_service/config.rs @@ -25,12 +25,14 @@ use std::time::Duration; #[derive(Clone)] pub struct OutputManagerServiceConfig { pub base_node_query_timeout: Duration, + pub max_utxo_query_size: usize, } impl Default for OutputManagerServiceConfig { fn default() -> Self { Self { base_node_query_timeout: Duration::from_secs(30), + max_utxo_query_size: 5000, } } } diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index a21149a4aa..ef8a49dbff 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -38,7 +38,7 @@ use crate::{ use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, SinkExt, Stream, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; +use std::{cmp, cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; use tari_broadcast_channel::Publisher; use tari_comms::types::CommsPublicKey; use tari_comms_dht::{ @@ -195,7 +195,7 @@ where let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); trace!(target: LOG_TARGET, "Handling Base Node Response, Trace: {}", msg.dht_header.message_tag); let result = self.handle_base_node_response(inner_msg).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}", origin_public_key, resp); + error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}, Trace: {}", origin_public_key, resp, msg.dht_header.message_tag); Err(resp) }); @@ -290,7 +290,7 @@ where .await .map(|_| OutputManagerResponse::BaseNodePublicKeySet), OutputManagerRequest::SyncWithBaseNode => self - .query_unspent_outputs_status(utxo_query_timeout_futures) + .query_unspent_outputs_status(utxo_query_timeout_futures, None) .await
.map(OutputManagerResponse::StartedBaseNodeSync), OutputManagerRequest::GetInvalidOutputs => { @@ -310,7 +310,7 @@ where } /// Handle an incoming basenode response message - pub async fn handle_base_node_response( + async fn handle_base_node_response( &mut self, response: BaseNodeProto::BaseNodeServiceResponse, ) -> Result<(), OutputManagerError> @@ -424,11 +424,11 @@ where utxo_query_timeout_futures: &mut FuturesUnordered>, ) -> Result<(), OutputManagerError> { - if self.pending_utxo_query_keys.remove(&query_key).is_some() { - error!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); - self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Finished queueing new Base Node query timeout"); + if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); + self.query_unspent_outputs_status(utxo_query_timeout_futures, Some(hashes)) + .await?; + let _ = self .event_publisher .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) @@ -447,49 +447,122 @@ where /// Send queries to the base node to check the status of all unspent outputs. If the outputs are no longer /// available their status will be updated in the wallet. - pub async fn query_unspent_outputs_status( + async fn query_unspent_outputs_status( &mut self, utxo_query_timeout_futures: &mut FuturesUnordered>, + specified_outputs: Option>>, ) -> Result { match self.base_node_public_key.as_ref() { None => Err(OutputManagerError::NoBaseNodeKeysProvided), Some(pk) => { - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = Vec::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - output_hashes.push(hash.clone()); - } - let request_key = OsRng.next_u64(); + let mut first_request_key = 0; + let mut unspent_outputs: Vec> = if let Some(hashes) = specified_outputs { + hashes + } else { + self.db + .get_unspent_outputs() + .await? 
+ .iter() + .map(|uo| uo.hash.clone()) + .collect() + }; - let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { - outputs: output_hashes.clone(), - }); + // Determine how many rounds of base node request we need to query all the outputs in batches of + // max_utxo_query_size + let rounds = + ((unspent_outputs.len() as f32) / (self.config.max_utxo_query_size as f32)).ceil() as usize; - let service_request = BaseNodeProto::BaseNodeServiceRequest { - request_key, - request: Some(request), - }; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "About to attempt to send query to base node"); - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::None, - OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), - ) - .await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Query sent to Base Node"); - self.pending_utxo_query_keys.insert(request_key, output_hashes); - let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); - utxo_query_timeout_futures.push(state_timeout.delay().boxed()); - debug!( - target: LOG_TARGET, - "Output Manager Sync query ({}) sent to Base Node", request_key - ); - Ok(request_key) + for r in 0..rounds { + let mut output_hashes = Vec::new(); + for uo_hash in + unspent_outputs.drain(..cmp::min(self.config.max_utxo_query_size, unspent_outputs.len())) + { + output_hashes.push(uo_hash); + } + let request_key = OsRng.next_u64(); + if first_request_key == 0 { + first_request_key = request_key; + } + + let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { + outputs: output_hashes.clone(), + }); + + let service_request = BaseNodeProto::BaseNodeServiceRequest { + request_key, + request: Some(request), + }; + + let send_message_response = self + .outbound_message_service + .send_direct( + pk.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), + ) + .await?; + + // Here we are going to spawn a non-blocking task that will monitor and log the progress of the + // send process. 
+ tokio::spawn(async move { + match send_message_response.resolve_ok().await { + None => trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO Sync query ({}) to Base Node", + request_key + ), + Some(send_states) => { + if send_states.len() == 1 { + trace!( + target: LOG_TARGET, + "Output Manager UTXO Sync query ({}) queued for sending with Message {}", + request_key, + send_states[0].tag, + ); + let message_tag = send_states[0].tag; + if send_states.wait_single().await { + trace!( + target: LOG_TARGET, + "Output Manager UTXO Sync query ({}) successfully sent to Base Node with \ + Message {}", + request_key, + message_tag, + ) + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO Sync query ({}) to Base Node with \ + Message {}", + request_key, + message_tag, + ); + } + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO Sync query ({}) to Base Node", + request_key + ) + } + }, + } + }); + + self.pending_utxo_query_keys.insert(request_key, output_hashes); + let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); + utxo_query_timeout_futures.push(state_timeout.delay().boxed()); + debug!( + target: LOG_TARGET, + "Output Manager Sync query ({}) sent to Base Node, part {} of {} requests", + request_key, + r + 1, + rounds + ); + } + // We are just going to return the first request key for use by the front end. It is very unlikely that + // a mobile wallet will ever have this query split up + Ok(first_request_key) }, } } @@ -500,14 +573,14 @@ where Ok(self.db.add_unspent_output(output).await?) } - pub async fn get_balance(&self) -> Result { + async fn get_balance(&self) -> Result { let balance = self.db.get_balance().await?; trace!(target: LOG_TARGET, "Balance: {:?}", balance); Ok(balance) } /// Request a spending key to be used to accept a transaction from a sender. - pub async fn get_recipient_spending_key( + async fn get_recipient_spending_key( &mut self, tx_id: TxId, amount: MicroTari, @@ -634,7 +707,7 @@ where /// Confirm that a transaction has finished being negotiated between parties so the short-term encumberance can be /// made official - pub async fn confirm_encumberance(&mut self, tx_id: u64) -> Result<(), OutputManagerError> { + async fn confirm_encumberance(&mut self, tx_id: u64) -> Result<(), OutputManagerError> { self.db.confirm_encumbered_outputs(tx_id).await?; Ok(()) @@ -643,7 +716,7 @@ where /// Confirm that a received or sent transaction and its outputs have been detected on the base chain. The inputs and /// outputs are checked to see that they match what the stored PendingTransaction contains. This will /// be called by the Transaction Service which monitors the base chain. - pub async fn confirm_transaction( + async fn confirm_transaction( &mut self, tx_id: u64, inputs: &[TransactionInput], @@ -698,7 +771,7 @@ where } /// Go through the pending transaction and if any have existed longer than the specified duration, cancel them - pub async fn timeout_pending_transactions(&mut self, period: Duration) -> Result<(), OutputManagerError> { + async fn timeout_pending_transactions(&mut self, period: Duration) -> Result<(), OutputManagerError> { Ok(self.db.timeout_pending_transaction_outputs(period).await?) 
} @@ -793,7 +866,8 @@ where self.base_node_public_key = Some(base_node_public_key); if startup_query { - self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; + self.query_unspent_outputs_status(utxo_query_timeout_futures, None) + .await?; } Ok(()) } @@ -816,7 +890,7 @@ where Ok(self.db.get_invalid_outputs().await?) } - pub async fn create_coin_split( + async fn create_coin_split( &mut self, amount_per_split: MicroTari, split_count: usize, diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index acfdacf353..54436c0736 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -42,7 +42,10 @@ use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServi use tari_core::{ base_node::proto::{ base_node as BaseNodeProto, - base_node::base_node_service_response::Response as BaseNodeResponseProto, + base_node::{ + base_node_service_request::Request, + base_node_service_response::Response as BaseNodeResponseProto, + }, }, transactions::{ fee::Fee, @@ -57,7 +60,7 @@ use tari_crypto::{ commitment::HomomorphicCommitmentFactory, keys::SecretKey, range_proof::RangeProofService, - tari_utilities::ByteArray, + tari_utilities::{hash::Hashable, ByteArray}, }; use tari_p2p::domain_message::DomainMessage; use tari_service_framework::reply_channel; @@ -108,6 +111,7 @@ pub fn setup_output_manager_service( .block_on(OutputManagerService::new( OutputManagerServiceConfig { base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, }, outbound_message_requester.clone(), ts_handle.clone(), @@ -651,16 +655,30 @@ fn test_startup_utxo_scan() { let (mut oms, outbound_service, _shutdown, mut base_node_response_sender, _) = setup_output_manager_service(&mut runtime, OutputManagerMemoryDatabase::new()); + let mut hashes = Vec::new(); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; - let output1 = UnblindedOutput::new(MicroTari::from(value1), key1, None); - + let output1 = UnblindedOutput::new(MicroTari::from(value1), key1.clone(), None); + let tx_output1 = output1.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output1.hash()); runtime.block_on(oms.add_output(output1.clone())).unwrap(); + let key2 = PrivateKey::random(&mut OsRng); let value2 = 800; - let output2 = UnblindedOutput::new(MicroTari::from(value2), key2, None); + let output2 = UnblindedOutput::new(MicroTari::from(value2), key2.clone(), None); + let tx_output2 = output2.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output2.hash()); + runtime.block_on(oms.add_output(output2.clone())).unwrap(); + let key3 = PrivateKey::random(&mut OsRng); + let value3 = 900; + let output3 = UnblindedOutput::new(MicroTari::from(value3), key3.clone(), None); + let tx_output3 = output3.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output3.hash()); + + runtime.block_on(oms.add_output(output3.clone())).unwrap(); + let base_node_identity = NodeIdentity::random( &mut OsRng, "/ip4/127.0.0.1/tcp/58217".parse().unwrap(), @@ -672,16 +690,57 @@ fn test_startup_utxo_scan() { .block_on(oms.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); + outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .expect("call wait 1"); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request1: 
BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request1.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request2.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash2"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + let result_stream = runtime.block_on(async { collect_stream!( oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 1, + take = 2, timeout = Duration::from_secs(60) ) }); assert_eq!( - 1, + 2, result_stream.iter().fold(0, |acc, item| { if let OutputManagerEvent::BaseNodeSyncRequestTimedOut(_) = item { acc + 1 @@ -691,10 +750,10 @@ fn test_startup_utxo_scan() { }) ); - let key3 = PrivateKey::random(&mut OsRng); - let value3 = 900; - let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); - runtime.block_on(oms.add_output(output3.clone())).unwrap(); + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 1000; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4, None); + runtime.block_on(oms.add_output(output4.clone())).unwrap(); let (_, body) = outbound_service.pop_call().unwrap(); let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); @@ -702,6 +761,7 @@ fn test_startup_utxo_scan() { .decode_part::(1) .unwrap() .unwrap(); + let _ = outbound_service.pop_call().unwrap(); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 0); @@ -741,33 +801,44 @@ fn test_startup_utxo_scan() { ))) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, - timeout = Duration::from_secs(60) - ) - }); + let mut event_stream = oms.get_event_stream_fused(); - assert_eq!( - 1, - result_stream.iter().fold(0, |acc, item| { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = item { - acc + 1 - } else { - acc + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! 
{ + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, } - }) - ); + } + assert!(acc >= 1, "Did not receive enough responses"); + }); let invalid_outputs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_outputs.len(), 1); - assert_eq!(invalid_outputs[0], output2); + let check2 = invalid_outputs[0] == output2; + let check3 = invalid_outputs[0] == output3; + + assert!(check2 || check3, "One of these outputs should be invalid"); let unspent_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); - assert_eq!(unspent_outputs.len(), 2); + assert_eq!(unspent_outputs.len(), 3); assert!(unspent_outputs.iter().find(|uo| uo == &&output1).is_some()); - assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()); + if check2 { + assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()) + } else { + assert!(unspent_outputs.iter().find(|uo| uo == &&output2).is_some()) + } runtime.block_on(oms.sync_with_base_node()).unwrap(); @@ -778,6 +849,13 @@ fn test_startup_utxo_scan() { .unwrap() .unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 1); @@ -787,6 +865,7 @@ fn test_startup_utxo_scan() { BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, )), }; + runtime .block_on(base_node_response_sender.send(create_dummy_message( base_node_response, @@ -794,6 +873,19 @@ fn test_startup_utxo_scan() { ))) .unwrap(); + let base_node_response2 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request2.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response2, + base_node_identity.public_key(), + ))) + .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); runtime.block_on(async { @@ -802,9 +894,9 @@ fn test_startup_utxo_scan() { loop { futures::select! 
{ event = event_stream.select_next_some() => { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + if let OutputManagerEvent::ReceiveBaseNodeResponse(r) = (*event).clone() { acc += 1; - if acc >= 2 { + if acc >= 3 { break; } } @@ -814,11 +906,11 @@ fn test_startup_utxo_scan() { }, } } - assert!(acc >= 2, "Did not receive enough responses"); + assert!(acc >= 3, "Did not receive enough responses"); }); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); - assert_eq!(invalid_txs.len(), 3); + assert_eq!(invalid_txs.len(), 4); } fn sending_transaction_with_short_term_clear(backend: T) { diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 5ef1dc96dc..c6cb7d2c1e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -36,6 +36,7 @@ use futures::{ stream::Fuse, StreamExt, }; +use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, time::Duration, @@ -43,6 +44,8 @@ use std::{ use tari_comms::message::MessageTag; use tokio::time::delay_for; +const LOG_TARGET: &str = "mock::outbound_requester"; + /// Creates a mock outbound request "handler" for testing purposes. /// /// Each time a request is expected, handle_next should be called. @@ -163,6 +166,11 @@ impl OutboundServiceMock { while let Some(req) = self.receiver.next().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { + trace!( + target: LOG_TARGET, + "Send message request received with length of {} bytes", + body.len() + ); let behaviour = self.mock_state.get_behaviour(); match (*params).clone().broadcast_strategy {
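
The batching arithmetic the patch introduces is easy to check in isolation. Below is a minimal, standalone Rust sketch (not part of the patch; the helper name batch_hashes is invented for illustration) that mirrors the drain/min/ceil logic added to query_unspent_outputs_status: the number of base node requests is the ceiling of the hash count divided by max_utxo_query_size, and the hash list is drained in chunks of at most that size. With the test configuration of max_utxo_query_size = 2, three tracked outputs produce two requests, which is why the test waits for two outbound calls and two timeout events.

// Standalone sketch, assuming plain Vec<u8> hashes; not the wallet's actual API.
fn batch_hashes(mut hashes: Vec<Vec<u8>>, max_utxo_query_size: usize) -> Vec<Vec<Vec<u8>>> {
    // Number of base node requests needed to cover all hashes.
    let rounds = ((hashes.len() as f32) / (max_utxo_query_size as f32)).ceil() as usize;
    let mut batches = Vec::with_capacity(rounds);
    for _ in 0..rounds {
        // Take at most max_utxo_query_size hashes per request; the final batch may be smaller.
        let take = std::cmp::min(max_utxo_query_size, hashes.len());
        batches.push(hashes.drain(..take).collect());
    }
    batches
}

fn main() {
    // Mirrors the unit test setup: max_utxo_query_size = 2 and three output hashes,
    // yielding two requests of sizes 2 and 1.
    let hashes: Vec<Vec<u8>> = (0u8..3).map(|i| vec![i; 32]).collect();
    let batches = batch_hashes(hashes, 2);
    assert_eq!(
        batches.iter().map(|b| b.len()).collect::<Vec<_>>(),
        vec![2usize, 1]
    );
    println!("{} requests needed", batches.len());
}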