diff --git a/base_layer/wallet/src/output_manager_service/config.rs b/base_layer/wallet/src/output_manager_service/config.rs index 9b7ba8c89c..9c8eecc2ba 100644 --- a/base_layer/wallet/src/output_manager_service/config.rs +++ b/base_layer/wallet/src/output_manager_service/config.rs @@ -25,12 +25,14 @@ use std::time::Duration; #[derive(Clone)] pub struct OutputManagerServiceConfig { pub base_node_query_timeout: Duration, + pub max_utxo_query_size: usize, } impl Default for OutputManagerServiceConfig { fn default() -> Self { Self { base_node_query_timeout: Duration::from_secs(30), + max_utxo_query_size: 5000, } } } diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index a21149a4aa..f94c59c2ee 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -38,7 +38,7 @@ use crate::{ use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, SinkExt, Stream, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; +use std::{cmp, cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; use tari_broadcast_channel::Publisher; use tari_comms::types::CommsPublicKey; use tari_comms_dht::{ @@ -99,6 +99,7 @@ where TBackend: OutputManagerBackend + 'static factories: CryptoFactories, base_node_public_key: Option, pending_utxo_query_keys: HashMap>>, + pending_revalidation_query_keys: HashMap>>, event_publisher: Publisher, } @@ -155,6 +156,7 @@ where factories, base_node_public_key: None, pending_utxo_query_keys: HashMap::new(), + pending_revalidation_query_keys: HashMap::new(), event_publisher, }) } @@ -195,7 +197,7 @@ where let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); trace!(target: LOG_TARGET, "Handling Base Node Response, Trace: {}", msg.dht_header.message_tag); let result = self.handle_base_node_response(inner_msg).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}", origin_public_key, resp); + error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}, Trace: {}", origin_public_key, resp, msg.dht_header.message_tag); Err(resp) }); @@ -324,86 +326,153 @@ where }, }; - // Only process requests with a request_key that we are expecting. - let queried_hashes: Vec> = match self.pending_utxo_query_keys.remove(&request_key) { - None => { - trace!( - target: LOG_TARGET, - "Ignoring Base Node Response with unexpected request key ({}), it was not meant for this service.", - request_key - ); - return Ok(()); - }, - Some(qh) => qh, - }; + let mut unspent_query_handled = false; + let mut invalid_query_handled = false; - trace!( - target: LOG_TARGET, - "Handling a Base Node Response meant for this service" - ); + // Check if the received key is in the pending UTXO query list to be handled + if let Some(queried_hashes) = self.pending_utxo_query_keys.remove(&request_key) { + trace!( + target: LOG_TARGET, + "Handling a Base Node Response for a Unspent Outputs request ({})", + request_key + ); - // Construct a HashMap of all the unspent outputs - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; + // Construct a HashMap of all the unspent outputs + let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = HashMap::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - if queried_hashes.iter().any(|h| &hash == h) { - output_hashes.insert(hash.clone(), uo.clone()); + let mut output_hashes = HashMap::new(); + for uo in unspent_outputs.iter() { + let hash = uo.hash.clone(); + if queried_hashes.iter().any(|h| &hash == h) { + output_hashes.insert(hash.clone(), uo.clone()); + } } - } - // Go through all the returned UTXOs and if they are in the hashmap remove them - for output in response.iter() { - let response_hash = TransactionOutput::try_from(output.clone()) - .map_err(OutputManagerError::ConversionError)? - .hash(); + // Go through all the returned UTXOs and if they are in the hashmap remove them + for output in response.iter() { + let response_hash = TransactionOutput::try_from(output.clone()) + .map_err(OutputManagerError::ConversionError)? + .hash(); - let _ = output_hashes.remove(&response_hash); - } + let _ = output_hashes.remove(&response_hash); + } - // If there are any remaining Unspent Outputs we will move them to the invalid collection - for (_k, v) in output_hashes { - // Get the transaction these belonged to so we can display the kernel signature of the transaction this - // output belonged to. + // If there are any remaining Unspent Outputs we will move them to the invalid collection + for (_k, v) in output_hashes { + // Get the transaction these belonged to so we can display the kernel signature of the transaction + // this output belonged to. - warn!( - target: LOG_TARGET, - "Output with value {} not returned from Base Node query and is thus being invalidated", - v.unblinded_output.value - ); - // If the output that is being invalidated has an associated TxId then get the kernel signature of the - // transaction and display for easier debugging - if let Some(tx_id) = self.db.invalidate_output(v).await? { - if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await { + warn!( + target: LOG_TARGET, + "Output with value {} not returned from Base Node query and is thus being invalidated", + v.unblinded_output.value + ); + // If the output that is being invalidated has an associated TxId then get the kernel signature of + // the transaction and display for easier debugging + if let Some(tx_id) = self.db.invalidate_output(v).await? { + if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await { + info!( + target: LOG_TARGET, + "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: \ + {}", + transaction.tx_id, + transaction.message, + transaction.transaction.body.kernels()[0] + .excess_sig + .get_signature() + .to_hex() + ) + } + } else { info!( target: LOG_TARGET, - "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: {}", - transaction.tx_id, - transaction.message, - transaction.transaction.body.kernels()[0] - .excess_sig - .get_signature() - .to_hex() - ) + "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost \ + to a Re-Org" + ); } - } else { - info!( - target: LOG_TARGET, - "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost to a \ - Re-Org" - ); } + unspent_query_handled = true; + debug!( + target: LOG_TARGET, + "Handled Base Node response for Unspent Outputs Query {}", request_key + ); + }; + + // Check if the received key is in the Invalid UTXO query list waiting to be handled + if let Some(_) = self.pending_revalidation_query_keys.remove(&request_key) { + trace!( + target: LOG_TARGET, + "Handling a Base Node Response for a Invalid Outputs request ({})", + request_key + ); + let invalid_outputs = self.db.get_invalid_outputs().await?; + + for output in response.iter() { + let response_hash = TransactionOutput::try_from(output.clone()) + .map_err(OutputManagerError::ConversionError)? + .hash(); + + if let Some(output) = invalid_outputs.iter().find(|o| o.hash == response_hash) { + if let Ok(_) = self + .db + .revalidate_output(output.unblinded_output.spending_key.clone()) + .await + { + trace!( + target: LOG_TARGET, + "Output with value {} has been restored to a valid spendable output", + output.unblinded_output.value + ); + } + } + } + invalid_query_handled = true; + debug!( + target: LOG_TARGET, + "Handled Base Node response for Invalid Outputs Query {}", request_key + ); } - debug!( - target: LOG_TARGET, - "Handled Base Node response for Query {}", request_key - ); + if unspent_query_handled || invalid_query_handled { + let _ = self + .event_publisher + .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) + .await + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + } + + Ok(()) + } + + /// Handle the timeout of a pending UTXO query. + pub async fn handle_utxo_query_timeout( + &mut self, + query_key: u64, + utxo_query_timeout_futures: &mut FuturesUnordered>, + ) -> Result<(), OutputManagerError> + { + if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Unspent Outputs Query {} timed out", query_key); + self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::UnspentOutputs) + .await?; + } + + if let Some(hashes) = self.pending_revalidation_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Invalid Outputs Query {} timed out", query_key); + self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::InvalidOutputs) + .await?; + } let _ = self .event_publisher - .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) + .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) .await .map_err(|e| { trace!( @@ -413,83 +482,178 @@ where ); e }); - Ok(()) } - /// Handle the timeout of a pending UTXO query. - pub async fn handle_utxo_query_timeout( + pub async fn query_unspent_outputs_status( &mut self, - query_key: u64, utxo_query_timeout_futures: &mut FuturesUnordered>, - ) -> Result<(), OutputManagerError> + ) -> Result { - if self.pending_utxo_query_keys.remove(&query_key).is_some() { - error!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); - self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Finished queueing new Base Node query timeout"); - let _ = self - .event_publisher - .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) - .await - .map_err(|e| { - trace!( - target: LOG_TARGET, - "Error sending event, usually because there are no subscribers: {:?}", - e - ); - e - }); + let unspent_output_hashes = self + .db + .get_unspent_outputs() + .await? + .iter() + .map(|uo| uo.hash.clone()) + .collect(); + + let key = self + .query_outputs_status( + utxo_query_timeout_futures, + unspent_output_hashes, + UtxoQueryType::UnspentOutputs, + ) + .await?; + + Ok(key) + } + + pub async fn query_invalid_outputs_status( + &mut self, + utxo_query_timeout_futures: &mut FuturesUnordered>, + ) -> Result + { + let invalid_output_hashes: Vec> = self + .db + .get_invalid_outputs() + .await? + .iter() + .map(|uo| uo.hash.clone()) + .collect(); + + let mut key = 0; + if !invalid_output_hashes.is_empty() { + key = self + .query_outputs_status( + utxo_query_timeout_futures, + invalid_output_hashes, + UtxoQueryType::InvalidOutputs, + ) + .await?; } - Ok(()) + Ok(key) } - /// Send queries to the base node to check the status of all unspent outputs. If the outputs are no longer - /// available their status will be updated in the wallet. - pub async fn query_unspent_outputs_status( + /// Send queries to the base node to check the status of all specified outputs. + async fn query_outputs_status( &mut self, utxo_query_timeout_futures: &mut FuturesUnordered>, + mut outputs_to_query: Vec>, + query_type: UtxoQueryType, ) -> Result { match self.base_node_public_key.as_ref() { None => Err(OutputManagerError::NoBaseNodeKeysProvided), Some(pk) => { - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = Vec::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - output_hashes.push(hash.clone()); - } - let request_key = OsRng.next_u64(); + let mut first_request_key = 0; - let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { - outputs: output_hashes.clone(), - }); + // Determine how many rounds of base node request we need to query all the outputs in batches of + // max_utxo_query_size + let rounds = + ((outputs_to_query.len() as f32) / (self.config.max_utxo_query_size as f32 + 0.1)) as usize + 1; - let service_request = BaseNodeProto::BaseNodeServiceRequest { - request_key, - request: Some(request), - }; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "About to attempt to send query to base node"); - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::None, - OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), - ) - .await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Query sent to Base Node"); - self.pending_utxo_query_keys.insert(request_key, output_hashes); - let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); - utxo_query_timeout_futures.push(state_timeout.delay().boxed()); - debug!( - target: LOG_TARGET, - "Output Manager Sync query ({}) sent to Base Node", request_key - ); - Ok(request_key) + for r in 0..rounds { + let mut output_hashes = Vec::new(); + for uo_hash in + outputs_to_query.drain(..cmp::min(self.config.max_utxo_query_size, outputs_to_query.len())) + { + output_hashes.push(uo_hash); + } + let request_key = OsRng.next_u64(); + if first_request_key == 0 { + first_request_key = request_key; + } + + let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { + outputs: output_hashes.clone(), + }); + + let service_request = BaseNodeProto::BaseNodeServiceRequest { + request_key, + request: Some(request), + }; + + let send_message_response = self + .outbound_message_service + .send_direct( + pk.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), + ) + .await?; + + // Here we are going to spawn a non-blocking task that will monitor and log the progress of the + // send process. + let owned_request_key = request_key; + tokio::spawn(async move { + match send_message_response.resolve_ok().await { + None => trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO query ({}) to Base Node", + owned_request_key + ), + Some(send_states) => { + if send_states.len() == 1 { + trace!( + target: LOG_TARGET, + "Output Manager UTXO query ({}) queued for sending with Message {}", + owned_request_key, + send_states[0].tag, + ); + let message_tag = send_states[0].tag; + if send_states.wait_single().await { + trace!( + target: LOG_TARGET, + "Output Manager UTXO query ({}) successfully sent to Base Node with \ + Message {}", + owned_request_key, + message_tag, + ) + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO query ({}) to Base Node with Message \ + {}", + owned_request_key, + message_tag, + ); + } + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO query ({}) to Base Node", + owned_request_key + ) + } + }, + } + }); + + match query_type { + UtxoQueryType::UnspentOutputs => { + self.pending_utxo_query_keys.insert(request_key, output_hashes); + }, + UtxoQueryType::InvalidOutputs => { + self.pending_revalidation_query_keys.insert(request_key, output_hashes); + }, + } + + let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); + utxo_query_timeout_futures.push(state_timeout.delay().boxed()); + + debug!( + target: LOG_TARGET, + "Output Manager {} query ({}) sent to Base Node, part {} of {} requests", + query_type, + request_key, + r + 1, + rounds + ); + } + // We are just going to return the first request key for use by the front end. It is very unlikely that + // a mobile wallet will ever have this query split up + Ok(first_request_key) }, } } @@ -794,6 +958,7 @@ where if startup_query { self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; + self.query_invalid_outputs_status(utxo_query_timeout_futures).await?; } Ok(()) } @@ -952,3 +1117,17 @@ impl fmt::Display for Balance { Ok(()) } } + +enum UtxoQueryType { + UnspentOutputs, + InvalidOutputs, +} + +impl fmt::Display for UtxoQueryType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UtxoQueryType::UnspentOutputs => write!(f, "Unspent Outputs"), + UtxoQueryType::InvalidOutputs => write!(f, "Invalid Outputs"), + } + } +} diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs index 0c3a35be1e..eaea0f0b45 100644 --- a/base_layer/wallet/src/output_manager_service/storage/database.rs +++ b/base_layer/wallet/src/output_manager_service/storage/database.rs @@ -83,6 +83,8 @@ pub trait OutputManagerBackend: Send + Sync { /// If an unspent output is detected as invalid (i.e. not available on the blockchain) then it should be moved to /// the invalid outputs collection. The function will return the last recorded TxId associated with this output. fn invalidate_unspent_output(&self, output: &DbUnblindedOutput) -> Result, OutputManagerStorageError>; + /// If an invalid output is found to be valid this function will turn it back into an unspent output + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError>; } /// Holds the outputs that have been selected for a given pending transaction waiting for confirmation @@ -502,6 +504,14 @@ where T: OutputManagerBackend + 'static .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) .and_then(|inner_result| inner_result) } + + pub async fn revalidate_output(&self, spending_key: BlindingFactor) -> Result<(), OutputManagerStorageError> { + let db_clone = self.db.clone(); + tokio::task::spawn_blocking(move || db_clone.revalidate_unspent_output(&spending_key)) + .await + .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) + .and_then(|inner_result| inner_result) + } } fn unexpected_result(req: DbKey, res: DbValue) -> Result { diff --git a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs index 3b6c5fe2c0..5c43865148 100644 --- a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs @@ -42,6 +42,7 @@ use std::{ sync::{Arc, RwLock}, time::Duration, }; +use tari_core::transactions::types::BlindingFactor; /// This structure is an In-Memory database backend that implements the `OutputManagerBackend` trait and provides all /// the functionality required by the trait. @@ -372,6 +373,22 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase { None => Err(OutputManagerStorageError::ValuesNotFound), } } + + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> { + let mut db = acquire_write_lock!(self.db); + match db + .invalid_outputs + .iter() + .position(|v| v.output.unblinded_output.spending_key == *spending_key) + { + Some(pos) => { + let output = db.invalid_outputs.remove(pos); + db.unspent_outputs.push(output.clone()); + Ok(()) + }, + None => Err(OutputManagerStorageError::ValuesNotFound), + } + } } // A struct that contains the extra info we are using in the Sql version of this backend diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index c2ea85e7a6..363773a557 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -54,7 +54,7 @@ use tari_core::{ transactions::{ tari_amount::MicroTari, transaction::{OutputFeatures, OutputFlags, UnblindedOutput}, - types::{CryptoFactories, PrivateKey}, + types::{BlindingFactor, CryptoFactories, PrivateKey}, }, }; use tari_crypto::tari_utilities::ByteArray; @@ -429,6 +429,23 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase { Ok(tx_id) } + + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> { + let conn = acquire_lock!(self.database_connection); + let output = OutputSql::find(&spending_key.to_vec(), &conn)?; + + if OutputStatus::try_from(output.status)? != OutputStatus::Invalid { + return Err(OutputManagerStorageError::ValuesNotFound); + } + let _ = output.update( + UpdateOutput { + status: Some(OutputStatus::Unspent), + tx_id: None, + }, + &(*conn), + )?; + Ok(()) + } } /// A utility function to construct a PendingTransactionOutputs structure for a TxId, set of Outputs and a Timestamp diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index acfdacf353..00f84bbb95 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -42,7 +42,10 @@ use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServi use tari_core::{ base_node::proto::{ base_node as BaseNodeProto, - base_node::base_node_service_response::Response as BaseNodeResponseProto, + base_node::{ + base_node_service_request::Request, + base_node_service_response::Response as BaseNodeResponseProto, + }, }, transactions::{ fee::Fee, @@ -57,7 +60,7 @@ use tari_crypto::{ commitment::HomomorphicCommitmentFactory, keys::SecretKey, range_proof::RangeProofService, - tari_utilities::ByteArray, + tari_utilities::{hash::Hashable, ByteArray}, }; use tari_p2p::domain_message::DomainMessage; use tari_service_framework::reply_channel; @@ -70,8 +73,9 @@ use tari_wallet::{ handle::{OutputManagerEvent, OutputManagerHandle}, service::OutputManagerService, storage::{ - database::{DbKey, DbValue, OutputManagerBackend, OutputManagerDatabase}, + database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, memory_db::OutputManagerMemoryDatabase, + models::DbUnblindedOutput, sqlite_db::OutputManagerSqliteDatabase, }, }, @@ -108,6 +112,7 @@ pub fn setup_output_manager_service( .block_on(OutputManagerService::new( OutputManagerServiceConfig { base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, }, outbound_message_requester.clone(), ts_handle.clone(), @@ -648,19 +653,60 @@ fn test_startup_utxo_scan() { let factories = CryptoFactories::default(); let mut runtime = Runtime::new().unwrap(); + let backend = OutputManagerMemoryDatabase::new(); + + let invalid_key = PrivateKey::random(&mut OsRng); + let invalid_value = 666; + let invalid_output = UnblindedOutput::new(MicroTari::from(invalid_value), invalid_key.clone(), None); + let invalid_hash = invalid_output.as_transaction_output(&factories).unwrap().hash(); + + backend + .write(WriteOperation::Insert(DbKeyValuePair::UnspentOutput( + invalid_output.spending_key.clone(), + Box::new(DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap()), + ))) + .unwrap(); + backend + .invalidate_unspent_output( + &DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap(), + ) + .unwrap(); let (mut oms, outbound_service, _shutdown, mut base_node_response_sender, _) = - setup_output_manager_service(&mut runtime, OutputManagerMemoryDatabase::new()); + setup_output_manager_service(&mut runtime, backend); + let mut hashes = Vec::new(); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; - let output1 = UnblindedOutput::new(MicroTari::from(value1), key1, None); - + let output1 = UnblindedOutput::new(MicroTari::from(value1), key1.clone(), None); + let tx_output1 = output1.as_transaction_output(&factories).unwrap(); + let output1_hash = tx_output1.hash(); + hashes.push(output1_hash.clone()); runtime.block_on(oms.add_output(output1.clone())).unwrap(); + let key2 = PrivateKey::random(&mut OsRng); let value2 = 800; - let output2 = UnblindedOutput::new(MicroTari::from(value2), key2, None); + let output2 = UnblindedOutput::new(MicroTari::from(value2), key2.clone(), None); + let tx_output2 = output2.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output2.hash()); + runtime.block_on(oms.add_output(output2.clone())).unwrap(); + let key3 = PrivateKey::random(&mut OsRng); + let value3 = 900; + let output3 = UnblindedOutput::new(MicroTari::from(value3), key3.clone(), None); + let tx_output3 = output3.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output3.hash()); + + runtime.block_on(oms.add_output(output3.clone())).unwrap(); + + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 901; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4.clone(), None); + let tx_output4 = output4.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output4.hash()); + + runtime.block_on(oms.add_output(output4.clone())).unwrap(); + let base_node_identity = NodeIdentity::random( &mut OsRng, "/ip4/127.0.0.1/tcp/58217".parse().unwrap(), @@ -672,16 +718,60 @@ fn test_startup_utxo_scan() { .block_on(oms.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); + outbound_service + .wait_call_count(3, Duration::from_secs(60)) + .expect("call wait 1"); + + let (_, _) = outbound_service.pop_call().unwrap(); // Burn the invalid request + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request1: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request1.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request2.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash2"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + let result_stream = runtime.block_on(async { collect_stream!( oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 1, + take = 3, timeout = Duration::from_secs(60) ) }); assert_eq!( - 1, + 3, result_stream.iter().fold(0, |acc, item| { if let OutputManagerEvent::BaseNodeSyncRequestTimedOut(_) = item { acc + 1 @@ -691,18 +781,88 @@ fn test_startup_utxo_scan() { }) ); - let key3 = PrivateKey::random(&mut OsRng); - let value3 = 900; - let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); - runtime.block_on(oms.add_output(output3.clone())).unwrap(); + // Test the response to the revalidation call first so as not to confuse the invalidation that happens during the + // responses to the Unspent UTXO queries + let mut invalid_request_key = 0; + let mut unspent_request_key_with_output1 = 0; + + for _ in 0..3 { + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); - let (_, body) = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); - let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body - .decode_part::(1) - .unwrap() + let request_hashes = if let Request::FetchUtxos(outputs) = bn_request.request.unwrap() { + outputs.outputs + } else { + assert!(false, "Wrong request type"); + Vec::new() + }; + + if request_hashes.iter().find(|i| **i == invalid_hash).is_some() { + invalid_request_key = bn_request.request_key; + } + if request_hashes.iter().find(|i| **i == output1_hash).is_some() { + unspent_request_key_with_output1 = bn_request.request_key; + } + } + assert_ne!(invalid_request_key, 0, "Should have found invalid request key"); + assert_ne!( + unspent_request_key_with_output1, 0, + "Should have found request key for request with output 1 in it" + ); + + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_txs.len(), 1); + let base_node_response = BaseNodeProto::BaseNodeServiceResponse { + request_key: invalid_request_key, + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { + outputs: vec![invalid_output.clone().as_transaction_output(&factories).unwrap().into()].into(), + }, + )), + }; + + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response, + base_node_identity.public_key(), + ))) .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! { + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(acc >= 1, "Did not receive enough responses"); + }); + + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_txs.len(), 0); + + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 1000; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4, None); + runtime.block_on(oms.add_output(output4.clone())).unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 0); @@ -726,7 +886,7 @@ fn test_startup_utxo_scan() { assert_eq!(invalid_txs.len(), 0); let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: bn_request.request_key.clone(), + request_key: unspent_request_key_with_output1, response: Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { outputs: vec![output1.clone().as_transaction_output(&factories).unwrap().into()].into(), @@ -741,33 +901,43 @@ fn test_startup_utxo_scan() { ))) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, - timeout = Duration::from_secs(60) - ) - }); - - assert_eq!( - 1, - result_stream.iter().fold(0, |acc, item| { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = item { - acc + 1 - } else { - acc + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! { + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, } - }) - ); + } + assert!(acc >= 1, "Did not receive enough responses"); + }); let invalid_outputs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_outputs.len(), 1); - assert_eq!(invalid_outputs[0], output2); + let check2 = invalid_outputs[0] == output2; + let check3 = invalid_outputs[0] == output3; + let check4 = invalid_outputs[0] == output4; + + assert!(check2 || check3 || check4, "One of these outputs should be invalid"); let unspent_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); - assert_eq!(unspent_outputs.len(), 2); + assert_eq!(unspent_outputs.len(), 5); assert!(unspent_outputs.iter().find(|uo| uo == &&output1).is_some()); - assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()); + if check2 { + assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()) + } else { + assert!(unspent_outputs.iter().find(|uo| uo == &&output2).is_some()) + } runtime.block_on(oms.sync_with_base_node()).unwrap(); @@ -778,6 +948,20 @@ fn test_startup_utxo_scan() { .unwrap() .unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request3: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 1); @@ -787,6 +971,7 @@ fn test_startup_utxo_scan() { BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, )), }; + runtime .block_on(base_node_response_sender.send(create_dummy_message( base_node_response, @@ -794,6 +979,32 @@ fn test_startup_utxo_scan() { ))) .unwrap(); + let base_node_response2 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request2.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response2, + base_node_identity.public_key(), + ))) + .unwrap(); + + let base_node_response3 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request3.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response3, + base_node_identity.public_key(), + ))) + .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); runtime.block_on(async { @@ -802,9 +1013,9 @@ fn test_startup_utxo_scan() { loop { futures::select! { event = event_stream.select_next_some() => { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + if let OutputManagerEvent::ReceiveBaseNodeResponse(r) = (*event).clone() { acc += 1; - if acc >= 2 { + if acc >= 4 { break; } } @@ -814,11 +1025,11 @@ fn test_startup_utxo_scan() { }, } } - assert!(acc >= 2, "Did not receive enough responses"); + assert!(acc >= 3, "Did not receive enough responses"); }); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); - assert_eq!(invalid_txs.len(), 3); + assert_eq!(invalid_txs.len(), 6); } fn sending_transaction_with_short_term_clear(backend: T) { diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index d0092c38de..b3daf35d50 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -290,6 +290,7 @@ pub fn test_db_backend(backend: T) { .block_on(db.invalidate_output(unspent_outputs[0].clone())) .unwrap(); let invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_outputs.len(), 1); assert_eq!(invalid_outputs[0], unspent_outputs[0]); @@ -297,6 +298,40 @@ pub fn test_db_backend(backend: T) { .block_on(db.invalidate_output(pending_txs[0].outputs_to_be_received[0].clone())) .unwrap(); assert_eq!(tx_id, Some(pending_txs[0].tx_id)); + + // test revalidating output + let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap(); + assert!( + unspent_outputs + .iter() + .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output) + .is_none(), + "Should not find ouput" + ); + + assert!(runtime + .block_on( + db.revalidate_output( + pending_txs[2].outputs_to_be_spent[0] + .unblinded_output + .spending_key + .clone() + ) + ) + .is_err()); + runtime + .block_on(db.revalidate_output(invalid_outputs[0].unblinded_output.spending_key.clone())) + .unwrap(); + let new_invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap(); + assert_eq!(new_invalid_outputs.len(), 1); + let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap(); + assert!( + unspent_outputs + .iter() + .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output) + .is_some(), + "Should find revalidated ouput" + ); } #[test] diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 5ef1dc96dc..c6cb7d2c1e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -36,6 +36,7 @@ use futures::{ stream::Fuse, StreamExt, }; +use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, time::Duration, @@ -43,6 +44,8 @@ use std::{ use tari_comms::message::MessageTag; use tokio::time::delay_for; +const LOG_TARGET: &str = "mock::outbound_requester"; + /// Creates a mock outbound request "handler" for testing purposes. /// /// Each time a request is expected, handle_next should be called. @@ -163,6 +166,11 @@ impl OutboundServiceMock { while let Some(req) = self.receiver.next().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { + trace!( + target: LOG_TARGET, + "Send message request received with length of {} bytes", + body.len() + ); let behaviour = self.mock_state.get_behaviour(); match (*params).clone().broadcast_strategy {