diff --git a/base_layer/wallet/src/output_manager_service/config.rs b/base_layer/wallet/src/output_manager_service/config.rs index 9b7ba8c89c..9c8eecc2ba 100644 --- a/base_layer/wallet/src/output_manager_service/config.rs +++ b/base_layer/wallet/src/output_manager_service/config.rs @@ -25,12 +25,14 @@ use std::time::Duration; #[derive(Clone)] pub struct OutputManagerServiceConfig { pub base_node_query_timeout: Duration, + pub max_utxo_query_size: usize, } impl Default for OutputManagerServiceConfig { fn default() -> Self { Self { base_node_query_timeout: Duration::from_secs(30), + max_utxo_query_size: 5000, } } } diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index a21149a4aa..ef8a49dbff 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -38,7 +38,7 @@ use crate::{ use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, SinkExt, Stream, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; +use std::{cmp, cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; use tari_broadcast_channel::Publisher; use tari_comms::types::CommsPublicKey; use tari_comms_dht::{ @@ -195,7 +195,7 @@ where let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); trace!(target: LOG_TARGET, "Handling Base Node Response, Trace: {}", msg.dht_header.message_tag); let result = self.handle_base_node_response(inner_msg).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}", origin_public_key, resp); + error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}, Trace: {}", origin_public_key, resp, msg.dht_header.message_tag); Err(resp) }); @@ -290,7 +290,7 @@ where .await .map(|_| OutputManagerResponse::BaseNodePublicKeySet), OutputManagerRequest::SyncWithBaseNode => self - .query_unspent_outputs_status(utxo_query_timeout_futures) + .query_unspent_outputs_status(utxo_query_timeout_futures, None) .await .map(OutputManagerResponse::StartedBaseNodeSync), OutputManagerRequest::GetInvalidOutputs => { @@ -310,7 +310,7 @@ where } /// Handle an incoming basenode response message - pub async fn handle_base_node_response( + async fn handle_base_node_response( &mut self, response: BaseNodeProto::BaseNodeServiceResponse, ) -> Result<(), OutputManagerError> @@ -424,11 +424,11 @@ where utxo_query_timeout_futures: &mut FuturesUnordered>, ) -> Result<(), OutputManagerError> { - if self.pending_utxo_query_keys.remove(&query_key).is_some() { - error!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); - self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Finished queueing new Base Node query timeout"); + if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); + self.query_unspent_outputs_status(utxo_query_timeout_futures, Some(hashes)) + .await?; + let _ = self .event_publisher .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) @@ -447,49 +447,122 @@ where /// Send queries to the base node to check the status of all unspent outputs. If the outputs are no longer /// available their status will be updated in the wallet. - pub async fn query_unspent_outputs_status( + async fn query_unspent_outputs_status( &mut self, utxo_query_timeout_futures: &mut FuturesUnordered>, + specified_outputs: Option>>, ) -> Result { match self.base_node_public_key.as_ref() { None => Err(OutputManagerError::NoBaseNodeKeysProvided), Some(pk) => { - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = Vec::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - output_hashes.push(hash.clone()); - } - let request_key = OsRng.next_u64(); + let mut first_request_key = 0; + let mut unspent_outputs: Vec> = if let Some(hashes) = specified_outputs { + hashes + } else { + self.db + .get_unspent_outputs() + .await? + .iter() + .map(|uo| uo.hash.clone()) + .collect() + }; - let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { - outputs: output_hashes.clone(), - }); + // Determine how many rounds of base node request we need to query all the outputs in batches of + // max_utxo_query_size + let rounds = + ((unspent_outputs.len() as f32) / (self.config.max_utxo_query_size as f32)).ceil() as usize; - let service_request = BaseNodeProto::BaseNodeServiceRequest { - request_key, - request: Some(request), - }; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "About to attempt to send query to base node"); - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::None, - OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), - ) - .await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Query sent to Base Node"); - self.pending_utxo_query_keys.insert(request_key, output_hashes); - let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); - utxo_query_timeout_futures.push(state_timeout.delay().boxed()); - debug!( - target: LOG_TARGET, - "Output Manager Sync query ({}) sent to Base Node", request_key - ); - Ok(request_key) + for r in 0..rounds { + let mut output_hashes = Vec::new(); + for uo_hash in + unspent_outputs.drain(..cmp::min(self.config.max_utxo_query_size, unspent_outputs.len())) + { + output_hashes.push(uo_hash); + } + let request_key = OsRng.next_u64(); + if first_request_key == 0 { + first_request_key = request_key; + } + + let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { + outputs: output_hashes.clone(), + }); + + let service_request = BaseNodeProto::BaseNodeServiceRequest { + request_key, + request: Some(request), + }; + + let send_message_response = self + .outbound_message_service + .send_direct( + pk.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), + ) + .await?; + + // Here we are going to spawn a non-blocking task that will monitor and log the progress of the + // send process. + tokio::spawn(async move { + match send_message_response.resolve_ok().await { + None => trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO Sync query ({}) to Base Node", + request_key + ), + Some(send_states) => { + if send_states.len() == 1 { + trace!( + target: LOG_TARGET, + "Output Manager UTXO Sync query ({}) queued for sending with Message {}", + request_key, + send_states[0].tag, + ); + let message_tag = send_states[0].tag; + if send_states.wait_single().await { + trace!( + target: LOG_TARGET, + "Output Manager UTXO Sync query ({}) successfully sent to Base Node with \ + Message {}", + request_key, + message_tag, + ) + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO Sync query ({}) to Base Node with \ + Message {}", + request_key, + message_tag, + ); + } + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO Sync query ({}) to Base Node", + request_key + ) + } + }, + } + }); + + self.pending_utxo_query_keys.insert(request_key, output_hashes); + let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); + utxo_query_timeout_futures.push(state_timeout.delay().boxed()); + debug!( + target: LOG_TARGET, + "Output Manager Sync query ({}) sent to Base Node, part {} of {} requests", + request_key, + r + 1, + rounds + ); + } + // We are just going to return the first request key for use by the front end. It is very unlikely that + // a mobile wallet will ever have this query split up + Ok(first_request_key) }, } } @@ -500,14 +573,14 @@ where Ok(self.db.add_unspent_output(output).await?) } - pub async fn get_balance(&self) -> Result { + async fn get_balance(&self) -> Result { let balance = self.db.get_balance().await?; trace!(target: LOG_TARGET, "Balance: {:?}", balance); Ok(balance) } /// Request a spending key to be used to accept a transaction from a sender. - pub async fn get_recipient_spending_key( + async fn get_recipient_spending_key( &mut self, tx_id: TxId, amount: MicroTari, @@ -634,7 +707,7 @@ where /// Confirm that a transaction has finished being negotiated between parties so the short-term encumberance can be /// made official - pub async fn confirm_encumberance(&mut self, tx_id: u64) -> Result<(), OutputManagerError> { + async fn confirm_encumberance(&mut self, tx_id: u64) -> Result<(), OutputManagerError> { self.db.confirm_encumbered_outputs(tx_id).await?; Ok(()) @@ -643,7 +716,7 @@ where /// Confirm that a received or sent transaction and its outputs have been detected on the base chain. The inputs and /// outputs are checked to see that they match what the stored PendingTransaction contains. This will /// be called by the Transaction Service which monitors the base chain. - pub async fn confirm_transaction( + async fn confirm_transaction( &mut self, tx_id: u64, inputs: &[TransactionInput], @@ -698,7 +771,7 @@ where } /// Go through the pending transaction and if any have existed longer than the specified duration, cancel them - pub async fn timeout_pending_transactions(&mut self, period: Duration) -> Result<(), OutputManagerError> { + async fn timeout_pending_transactions(&mut self, period: Duration) -> Result<(), OutputManagerError> { Ok(self.db.timeout_pending_transaction_outputs(period).await?) } @@ -793,7 +866,8 @@ where self.base_node_public_key = Some(base_node_public_key); if startup_query { - self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; + self.query_unspent_outputs_status(utxo_query_timeout_futures, None) + .await?; } Ok(()) } @@ -816,7 +890,7 @@ where Ok(self.db.get_invalid_outputs().await?) } - pub async fn create_coin_split( + async fn create_coin_split( &mut self, amount_per_split: MicroTari, split_count: usize, diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index acfdacf353..54436c0736 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -42,7 +42,10 @@ use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServi use tari_core::{ base_node::proto::{ base_node as BaseNodeProto, - base_node::base_node_service_response::Response as BaseNodeResponseProto, + base_node::{ + base_node_service_request::Request, + base_node_service_response::Response as BaseNodeResponseProto, + }, }, transactions::{ fee::Fee, @@ -57,7 +60,7 @@ use tari_crypto::{ commitment::HomomorphicCommitmentFactory, keys::SecretKey, range_proof::RangeProofService, - tari_utilities::ByteArray, + tari_utilities::{hash::Hashable, ByteArray}, }; use tari_p2p::domain_message::DomainMessage; use tari_service_framework::reply_channel; @@ -108,6 +111,7 @@ pub fn setup_output_manager_service( .block_on(OutputManagerService::new( OutputManagerServiceConfig { base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, }, outbound_message_requester.clone(), ts_handle.clone(), @@ -651,16 +655,30 @@ fn test_startup_utxo_scan() { let (mut oms, outbound_service, _shutdown, mut base_node_response_sender, _) = setup_output_manager_service(&mut runtime, OutputManagerMemoryDatabase::new()); + let mut hashes = Vec::new(); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; - let output1 = UnblindedOutput::new(MicroTari::from(value1), key1, None); - + let output1 = UnblindedOutput::new(MicroTari::from(value1), key1.clone(), None); + let tx_output1 = output1.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output1.hash()); runtime.block_on(oms.add_output(output1.clone())).unwrap(); + let key2 = PrivateKey::random(&mut OsRng); let value2 = 800; - let output2 = UnblindedOutput::new(MicroTari::from(value2), key2, None); + let output2 = UnblindedOutput::new(MicroTari::from(value2), key2.clone(), None); + let tx_output2 = output2.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output2.hash()); + runtime.block_on(oms.add_output(output2.clone())).unwrap(); + let key3 = PrivateKey::random(&mut OsRng); + let value3 = 900; + let output3 = UnblindedOutput::new(MicroTari::from(value3), key3.clone(), None); + let tx_output3 = output3.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output3.hash()); + + runtime.block_on(oms.add_output(output3.clone())).unwrap(); + let base_node_identity = NodeIdentity::random( &mut OsRng, "/ip4/127.0.0.1/tcp/58217".parse().unwrap(), @@ -672,16 +690,57 @@ fn test_startup_utxo_scan() { .block_on(oms.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); + outbound_service + .wait_call_count(2, Duration::from_secs(60)) + .expect("call wait 1"); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request1: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request1.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request2.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash2"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + let result_stream = runtime.block_on(async { collect_stream!( oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 1, + take = 2, timeout = Duration::from_secs(60) ) }); assert_eq!( - 1, + 2, result_stream.iter().fold(0, |acc, item| { if let OutputManagerEvent::BaseNodeSyncRequestTimedOut(_) = item { acc + 1 @@ -691,10 +750,10 @@ fn test_startup_utxo_scan() { }) ); - let key3 = PrivateKey::random(&mut OsRng); - let value3 = 900; - let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); - runtime.block_on(oms.add_output(output3.clone())).unwrap(); + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 1000; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4, None); + runtime.block_on(oms.add_output(output4.clone())).unwrap(); let (_, body) = outbound_service.pop_call().unwrap(); let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); @@ -702,6 +761,7 @@ fn test_startup_utxo_scan() { .decode_part::(1) .unwrap() .unwrap(); + let _ = outbound_service.pop_call().unwrap(); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 0); @@ -741,33 +801,44 @@ fn test_startup_utxo_scan() { ))) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, - timeout = Duration::from_secs(60) - ) - }); + let mut event_stream = oms.get_event_stream_fused(); - assert_eq!( - 1, - result_stream.iter().fold(0, |acc, item| { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = item { - acc + 1 - } else { - acc + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! { + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, } - }) - ); + } + assert!(acc >= 1, "Did not receive enough responses"); + }); let invalid_outputs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_outputs.len(), 1); - assert_eq!(invalid_outputs[0], output2); + let check2 = invalid_outputs[0] == output2; + let check3 = invalid_outputs[0] == output3; + + assert!(check2 || check3, "One of these outputs should be invalid"); let unspent_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); - assert_eq!(unspent_outputs.len(), 2); + assert_eq!(unspent_outputs.len(), 3); assert!(unspent_outputs.iter().find(|uo| uo == &&output1).is_some()); - assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()); + if check2 { + assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()) + } else { + assert!(unspent_outputs.iter().find(|uo| uo == &&output2).is_some()) + } runtime.block_on(oms.sync_with_base_node()).unwrap(); @@ -778,6 +849,13 @@ fn test_startup_utxo_scan() { .unwrap() .unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 1); @@ -787,6 +865,7 @@ fn test_startup_utxo_scan() { BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, )), }; + runtime .block_on(base_node_response_sender.send(create_dummy_message( base_node_response, @@ -794,6 +873,19 @@ fn test_startup_utxo_scan() { ))) .unwrap(); + let base_node_response2 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request2.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response2, + base_node_identity.public_key(), + ))) + .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); runtime.block_on(async { @@ -802,9 +894,9 @@ fn test_startup_utxo_scan() { loop { futures::select! { event = event_stream.select_next_some() => { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + if let OutputManagerEvent::ReceiveBaseNodeResponse(r) = (*event).clone() { acc += 1; - if acc >= 2 { + if acc >= 3 { break; } } @@ -814,11 +906,11 @@ fn test_startup_utxo_scan() { }, } } - assert!(acc >= 2, "Did not receive enough responses"); + assert!(acc >= 3, "Did not receive enough responses"); }); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); - assert_eq!(invalid_txs.len(), 3); + assert_eq!(invalid_txs.len(), 4); } fn sending_transaction_with_short_term_clear(backend: T) { diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 5ef1dc96dc..c6cb7d2c1e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -36,6 +36,7 @@ use futures::{ stream::Fuse, StreamExt, }; +use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, time::Duration, @@ -43,6 +44,8 @@ use std::{ use tari_comms::message::MessageTag; use tokio::time::delay_for; +const LOG_TARGET: &str = "mock::outbound_requester"; + /// Creates a mock outbound request "handler" for testing purposes. /// /// Each time a request is expected, handle_next should be called. @@ -163,6 +166,11 @@ impl OutboundServiceMock { while let Some(req) = self.receiver.next().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { + trace!( + target: LOG_TARGET, + "Send message request received with length of {} bytes", + body.len() + ); let behaviour = self.mock_state.get_behaviour(); match (*params).clone().broadcast_strategy {