From f0dec6419173c3d81864f67d1aa5293921d0f9c3 Mon Sep 17 00:00:00 2001 From: Philip Robinson Date: Fri, 5 Jun 2020 11:48:27 +0200 Subject: [PATCH] Add revalidation check to Output Manager service Outputs can become marked as invalid by the Output Manager service when they do not appear in the base nodes blockchain. Generally this means that they were re-orged out or have been spent by another copy of the wallet. However they can be reported as invalid incorrectly in some case like, for example, if the Base Node is not fully synced. This PR adds in an addition check that the Output Manager does on startup that will see if any of the invalid outputs has become valid again. If it has that output will changed back into a spendable output. --- .../src/output_manager_service/config.rs | 2 + .../src/output_manager_service/service.rs | 427 +++++++++++++----- .../storage/database.rs | 10 + .../storage/memory_db.rs | 17 + .../storage/sqlite_db.rs | 19 +- .../tests/output_manager_service/service.rs | 297 ++++++++++-- .../tests/output_manager_service/storage.rs | 35 ++ comms/dht/src/outbound/mock.rs | 8 + 8 files changed, 647 insertions(+), 168 deletions(-) diff --git a/base_layer/wallet/src/output_manager_service/config.rs b/base_layer/wallet/src/output_manager_service/config.rs index 9b7ba8c89c..9c8eecc2ba 100644 --- a/base_layer/wallet/src/output_manager_service/config.rs +++ b/base_layer/wallet/src/output_manager_service/config.rs @@ -25,12 +25,14 @@ use std::time::Duration; #[derive(Clone)] pub struct OutputManagerServiceConfig { pub base_node_query_timeout: Duration, + pub max_utxo_query_size: usize, } impl Default for OutputManagerServiceConfig { fn default() -> Self { Self { base_node_query_timeout: Duration::from_secs(30), + max_utxo_query_size: 5000, } } } diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index a21149a4aa..f94c59c2ee 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -38,7 +38,7 @@ use crate::{ use futures::{future::BoxFuture, pin_mut, stream::FuturesUnordered, FutureExt, SinkExt, Stream, StreamExt}; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::{cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; +use std::{cmp, cmp::Ordering, collections::HashMap, convert::TryFrom, fmt, sync::Mutex, time::Duration}; use tari_broadcast_channel::Publisher; use tari_comms::types::CommsPublicKey; use tari_comms_dht::{ @@ -99,6 +99,7 @@ where TBackend: OutputManagerBackend + 'static factories: CryptoFactories, base_node_public_key: Option, pending_utxo_query_keys: HashMap>>, + pending_revalidation_query_keys: HashMap>>, event_publisher: Publisher, } @@ -155,6 +156,7 @@ where factories, base_node_public_key: None, pending_utxo_query_keys: HashMap::new(), + pending_revalidation_query_keys: HashMap::new(), event_publisher, }) } @@ -195,7 +197,7 @@ where let (origin_public_key, inner_msg) = msg.clone().into_origin_and_inner(); trace!(target: LOG_TARGET, "Handling Base Node Response, Trace: {}", msg.dht_header.message_tag); let result = self.handle_base_node_response(inner_msg).await.or_else(|resp| { - error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}", origin_public_key, resp); + error!(target: LOG_TARGET, "Error handling base node service response from {}: {:?}, Trace: {}", origin_public_key, resp, msg.dht_header.message_tag); Err(resp) }); @@ -324,86 +326,153 @@ where }, }; - // Only process requests with a request_key that we are expecting. - let queried_hashes: Vec> = match self.pending_utxo_query_keys.remove(&request_key) { - None => { - trace!( - target: LOG_TARGET, - "Ignoring Base Node Response with unexpected request key ({}), it was not meant for this service.", - request_key - ); - return Ok(()); - }, - Some(qh) => qh, - }; + let mut unspent_query_handled = false; + let mut invalid_query_handled = false; - trace!( - target: LOG_TARGET, - "Handling a Base Node Response meant for this service" - ); + // Check if the received key is in the pending UTXO query list to be handled + if let Some(queried_hashes) = self.pending_utxo_query_keys.remove(&request_key) { + trace!( + target: LOG_TARGET, + "Handling a Base Node Response for a Unspent Outputs request ({})", + request_key + ); - // Construct a HashMap of all the unspent outputs - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; + // Construct a HashMap of all the unspent outputs + let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = HashMap::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - if queried_hashes.iter().any(|h| &hash == h) { - output_hashes.insert(hash.clone(), uo.clone()); + let mut output_hashes = HashMap::new(); + for uo in unspent_outputs.iter() { + let hash = uo.hash.clone(); + if queried_hashes.iter().any(|h| &hash == h) { + output_hashes.insert(hash.clone(), uo.clone()); + } } - } - // Go through all the returned UTXOs and if they are in the hashmap remove them - for output in response.iter() { - let response_hash = TransactionOutput::try_from(output.clone()) - .map_err(OutputManagerError::ConversionError)? - .hash(); + // Go through all the returned UTXOs and if they are in the hashmap remove them + for output in response.iter() { + let response_hash = TransactionOutput::try_from(output.clone()) + .map_err(OutputManagerError::ConversionError)? + .hash(); - let _ = output_hashes.remove(&response_hash); - } + let _ = output_hashes.remove(&response_hash); + } - // If there are any remaining Unspent Outputs we will move them to the invalid collection - for (_k, v) in output_hashes { - // Get the transaction these belonged to so we can display the kernel signature of the transaction this - // output belonged to. + // If there are any remaining Unspent Outputs we will move them to the invalid collection + for (_k, v) in output_hashes { + // Get the transaction these belonged to so we can display the kernel signature of the transaction + // this output belonged to. - warn!( - target: LOG_TARGET, - "Output with value {} not returned from Base Node query and is thus being invalidated", - v.unblinded_output.value - ); - // If the output that is being invalidated has an associated TxId then get the kernel signature of the - // transaction and display for easier debugging - if let Some(tx_id) = self.db.invalidate_output(v).await? { - if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await { + warn!( + target: LOG_TARGET, + "Output with value {} not returned from Base Node query and is thus being invalidated", + v.unblinded_output.value + ); + // If the output that is being invalidated has an associated TxId then get the kernel signature of + // the transaction and display for easier debugging + if let Some(tx_id) = self.db.invalidate_output(v).await? { + if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await { + info!( + target: LOG_TARGET, + "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: \ + {}", + transaction.tx_id, + transaction.message, + transaction.transaction.body.kernels()[0] + .excess_sig + .get_signature() + .to_hex() + ) + } + } else { info!( target: LOG_TARGET, - "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: {}", - transaction.tx_id, - transaction.message, - transaction.transaction.body.kernels()[0] - .excess_sig - .get_signature() - .to_hex() - ) + "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost \ + to a Re-Org" + ); } - } else { - info!( - target: LOG_TARGET, - "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost to a \ - Re-Org" - ); } + unspent_query_handled = true; + debug!( + target: LOG_TARGET, + "Handled Base Node response for Unspent Outputs Query {}", request_key + ); + }; + + // Check if the received key is in the Invalid UTXO query list waiting to be handled + if let Some(_) = self.pending_revalidation_query_keys.remove(&request_key) { + trace!( + target: LOG_TARGET, + "Handling a Base Node Response for a Invalid Outputs request ({})", + request_key + ); + let invalid_outputs = self.db.get_invalid_outputs().await?; + + for output in response.iter() { + let response_hash = TransactionOutput::try_from(output.clone()) + .map_err(OutputManagerError::ConversionError)? + .hash(); + + if let Some(output) = invalid_outputs.iter().find(|o| o.hash == response_hash) { + if let Ok(_) = self + .db + .revalidate_output(output.unblinded_output.spending_key.clone()) + .await + { + trace!( + target: LOG_TARGET, + "Output with value {} has been restored to a valid spendable output", + output.unblinded_output.value + ); + } + } + } + invalid_query_handled = true; + debug!( + target: LOG_TARGET, + "Handled Base Node response for Invalid Outputs Query {}", request_key + ); } - debug!( - target: LOG_TARGET, - "Handled Base Node response for Query {}", request_key - ); + if unspent_query_handled || invalid_query_handled { + let _ = self + .event_publisher + .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) + .await + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + } + + Ok(()) + } + + /// Handle the timeout of a pending UTXO query. + pub async fn handle_utxo_query_timeout( + &mut self, + query_key: u64, + utxo_query_timeout_futures: &mut FuturesUnordered>, + ) -> Result<(), OutputManagerError> + { + if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Unspent Outputs Query {} timed out", query_key); + self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::UnspentOutputs) + .await?; + } + + if let Some(hashes) = self.pending_revalidation_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Invalid Outputs Query {} timed out", query_key); + self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::InvalidOutputs) + .await?; + } let _ = self .event_publisher - .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) + .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) .await .map_err(|e| { trace!( @@ -413,83 +482,178 @@ where ); e }); - Ok(()) } - /// Handle the timeout of a pending UTXO query. - pub async fn handle_utxo_query_timeout( + pub async fn query_unspent_outputs_status( &mut self, - query_key: u64, utxo_query_timeout_futures: &mut FuturesUnordered>, - ) -> Result<(), OutputManagerError> + ) -> Result { - if self.pending_utxo_query_keys.remove(&query_key).is_some() { - error!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); - self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Finished queueing new Base Node query timeout"); - let _ = self - .event_publisher - .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) - .await - .map_err(|e| { - trace!( - target: LOG_TARGET, - "Error sending event, usually because there are no subscribers: {:?}", - e - ); - e - }); + let unspent_output_hashes = self + .db + .get_unspent_outputs() + .await? + .iter() + .map(|uo| uo.hash.clone()) + .collect(); + + let key = self + .query_outputs_status( + utxo_query_timeout_futures, + unspent_output_hashes, + UtxoQueryType::UnspentOutputs, + ) + .await?; + + Ok(key) + } + + pub async fn query_invalid_outputs_status( + &mut self, + utxo_query_timeout_futures: &mut FuturesUnordered>, + ) -> Result + { + let invalid_output_hashes: Vec> = self + .db + .get_invalid_outputs() + .await? + .iter() + .map(|uo| uo.hash.clone()) + .collect(); + + let mut key = 0; + if !invalid_output_hashes.is_empty() { + key = self + .query_outputs_status( + utxo_query_timeout_futures, + invalid_output_hashes, + UtxoQueryType::InvalidOutputs, + ) + .await?; } - Ok(()) + Ok(key) } - /// Send queries to the base node to check the status of all unspent outputs. If the outputs are no longer - /// available their status will be updated in the wallet. - pub async fn query_unspent_outputs_status( + /// Send queries to the base node to check the status of all specified outputs. + async fn query_outputs_status( &mut self, utxo_query_timeout_futures: &mut FuturesUnordered>, + mut outputs_to_query: Vec>, + query_type: UtxoQueryType, ) -> Result { match self.base_node_public_key.as_ref() { None => Err(OutputManagerError::NoBaseNodeKeysProvided), Some(pk) => { - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = Vec::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - output_hashes.push(hash.clone()); - } - let request_key = OsRng.next_u64(); + let mut first_request_key = 0; - let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { - outputs: output_hashes.clone(), - }); + // Determine how many rounds of base node request we need to query all the outputs in batches of + // max_utxo_query_size + let rounds = + ((outputs_to_query.len() as f32) / (self.config.max_utxo_query_size as f32 + 0.1)) as usize + 1; - let service_request = BaseNodeProto::BaseNodeServiceRequest { - request_key, - request: Some(request), - }; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "About to attempt to send query to base node"); - self.outbound_message_service - .send_direct( - pk.clone(), - OutboundEncryption::None, - OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), - ) - .await?; - // TODO Remove this once this bug is fixed - trace!(target: LOG_TARGET, "Query sent to Base Node"); - self.pending_utxo_query_keys.insert(request_key, output_hashes); - let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); - utxo_query_timeout_futures.push(state_timeout.delay().boxed()); - debug!( - target: LOG_TARGET, - "Output Manager Sync query ({}) sent to Base Node", request_key - ); - Ok(request_key) + for r in 0..rounds { + let mut output_hashes = Vec::new(); + for uo_hash in + outputs_to_query.drain(..cmp::min(self.config.max_utxo_query_size, outputs_to_query.len())) + { + output_hashes.push(uo_hash); + } + let request_key = OsRng.next_u64(); + if first_request_key == 0 { + first_request_key = request_key; + } + + let request = BaseNodeRequestProto::FetchUtxos(BaseNodeProto::HashOutputs { + outputs: output_hashes.clone(), + }); + + let service_request = BaseNodeProto::BaseNodeServiceRequest { + request_key, + request: Some(request), + }; + + let send_message_response = self + .outbound_message_service + .send_direct( + pk.clone(), + OutboundEncryption::None, + OutboundDomainMessage::new(TariMessageType::BaseNodeRequest, service_request), + ) + .await?; + + // Here we are going to spawn a non-blocking task that will monitor and log the progress of the + // send process. + let owned_request_key = request_key; + tokio::spawn(async move { + match send_message_response.resolve_ok().await { + None => trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO query ({}) to Base Node", + owned_request_key + ), + Some(send_states) => { + if send_states.len() == 1 { + trace!( + target: LOG_TARGET, + "Output Manager UTXO query ({}) queued for sending with Message {}", + owned_request_key, + send_states[0].tag, + ); + let message_tag = send_states[0].tag; + if send_states.wait_single().await { + trace!( + target: LOG_TARGET, + "Output Manager UTXO query ({}) successfully sent to Base Node with \ + Message {}", + owned_request_key, + message_tag, + ) + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO query ({}) to Base Node with Message \ + {}", + owned_request_key, + message_tag, + ); + } + } else { + trace!( + target: LOG_TARGET, + "Failed to send Output Manager UTXO query ({}) to Base Node", + owned_request_key + ) + } + }, + } + }); + + match query_type { + UtxoQueryType::UnspentOutputs => { + self.pending_utxo_query_keys.insert(request_key, output_hashes); + }, + UtxoQueryType::InvalidOutputs => { + self.pending_revalidation_query_keys.insert(request_key, output_hashes); + }, + } + + let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); + utxo_query_timeout_futures.push(state_timeout.delay().boxed()); + + debug!( + target: LOG_TARGET, + "Output Manager {} query ({}) sent to Base Node, part {} of {} requests", + query_type, + request_key, + r + 1, + rounds + ); + } + // We are just going to return the first request key for use by the front end. It is very unlikely that + // a mobile wallet will ever have this query split up + Ok(first_request_key) }, } } @@ -794,6 +958,7 @@ where if startup_query { self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; + self.query_invalid_outputs_status(utxo_query_timeout_futures).await?; } Ok(()) } @@ -952,3 +1117,17 @@ impl fmt::Display for Balance { Ok(()) } } + +enum UtxoQueryType { + UnspentOutputs, + InvalidOutputs, +} + +impl fmt::Display for UtxoQueryType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UtxoQueryType::UnspentOutputs => write!(f, "Unspent Outputs"), + UtxoQueryType::InvalidOutputs => write!(f, "Invalid Outputs"), + } + } +} diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs index 0c3a35be1e..eaea0f0b45 100644 --- a/base_layer/wallet/src/output_manager_service/storage/database.rs +++ b/base_layer/wallet/src/output_manager_service/storage/database.rs @@ -83,6 +83,8 @@ pub trait OutputManagerBackend: Send + Sync { /// If an unspent output is detected as invalid (i.e. not available on the blockchain) then it should be moved to /// the invalid outputs collection. The function will return the last recorded TxId associated with this output. fn invalidate_unspent_output(&self, output: &DbUnblindedOutput) -> Result, OutputManagerStorageError>; + /// If an invalid output is found to be valid this function will turn it back into an unspent output + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError>; } /// Holds the outputs that have been selected for a given pending transaction waiting for confirmation @@ -502,6 +504,14 @@ where T: OutputManagerBackend + 'static .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) .and_then(|inner_result| inner_result) } + + pub async fn revalidate_output(&self, spending_key: BlindingFactor) -> Result<(), OutputManagerStorageError> { + let db_clone = self.db.clone(); + tokio::task::spawn_blocking(move || db_clone.revalidate_unspent_output(&spending_key)) + .await + .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) + .and_then(|inner_result| inner_result) + } } fn unexpected_result(req: DbKey, res: DbValue) -> Result { diff --git a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs index 3b6c5fe2c0..5c43865148 100644 --- a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs @@ -42,6 +42,7 @@ use std::{ sync::{Arc, RwLock}, time::Duration, }; +use tari_core::transactions::types::BlindingFactor; /// This structure is an In-Memory database backend that implements the `OutputManagerBackend` trait and provides all /// the functionality required by the trait. @@ -372,6 +373,22 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase { None => Err(OutputManagerStorageError::ValuesNotFound), } } + + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> { + let mut db = acquire_write_lock!(self.db); + match db + .invalid_outputs + .iter() + .position(|v| v.output.unblinded_output.spending_key == *spending_key) + { + Some(pos) => { + let output = db.invalid_outputs.remove(pos); + db.unspent_outputs.push(output.clone()); + Ok(()) + }, + None => Err(OutputManagerStorageError::ValuesNotFound), + } + } } // A struct that contains the extra info we are using in the Sql version of this backend diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index c2ea85e7a6..363773a557 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -54,7 +54,7 @@ use tari_core::{ transactions::{ tari_amount::MicroTari, transaction::{OutputFeatures, OutputFlags, UnblindedOutput}, - types::{CryptoFactories, PrivateKey}, + types::{BlindingFactor, CryptoFactories, PrivateKey}, }, }; use tari_crypto::tari_utilities::ByteArray; @@ -429,6 +429,23 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase { Ok(tx_id) } + + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> { + let conn = acquire_lock!(self.database_connection); + let output = OutputSql::find(&spending_key.to_vec(), &conn)?; + + if OutputStatus::try_from(output.status)? != OutputStatus::Invalid { + return Err(OutputManagerStorageError::ValuesNotFound); + } + let _ = output.update( + UpdateOutput { + status: Some(OutputStatus::Unspent), + tx_id: None, + }, + &(*conn), + )?; + Ok(()) + } } /// A utility function to construct a PendingTransactionOutputs structure for a TxId, set of Outputs and a Timestamp diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index acfdacf353..00f84bbb95 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -42,7 +42,10 @@ use tari_comms_dht::outbound::mock::{create_outbound_service_mock, OutboundServi use tari_core::{ base_node::proto::{ base_node as BaseNodeProto, - base_node::base_node_service_response::Response as BaseNodeResponseProto, + base_node::{ + base_node_service_request::Request, + base_node_service_response::Response as BaseNodeResponseProto, + }, }, transactions::{ fee::Fee, @@ -57,7 +60,7 @@ use tari_crypto::{ commitment::HomomorphicCommitmentFactory, keys::SecretKey, range_proof::RangeProofService, - tari_utilities::ByteArray, + tari_utilities::{hash::Hashable, ByteArray}, }; use tari_p2p::domain_message::DomainMessage; use tari_service_framework::reply_channel; @@ -70,8 +73,9 @@ use tari_wallet::{ handle::{OutputManagerEvent, OutputManagerHandle}, service::OutputManagerService, storage::{ - database::{DbKey, DbValue, OutputManagerBackend, OutputManagerDatabase}, + database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, memory_db::OutputManagerMemoryDatabase, + models::DbUnblindedOutput, sqlite_db::OutputManagerSqliteDatabase, }, }, @@ -108,6 +112,7 @@ pub fn setup_output_manager_service( .block_on(OutputManagerService::new( OutputManagerServiceConfig { base_node_query_timeout: Duration::from_secs(10), + max_utxo_query_size: 2, }, outbound_message_requester.clone(), ts_handle.clone(), @@ -648,19 +653,60 @@ fn test_startup_utxo_scan() { let factories = CryptoFactories::default(); let mut runtime = Runtime::new().unwrap(); + let backend = OutputManagerMemoryDatabase::new(); + + let invalid_key = PrivateKey::random(&mut OsRng); + let invalid_value = 666; + let invalid_output = UnblindedOutput::new(MicroTari::from(invalid_value), invalid_key.clone(), None); + let invalid_hash = invalid_output.as_transaction_output(&factories).unwrap().hash(); + + backend + .write(WriteOperation::Insert(DbKeyValuePair::UnspentOutput( + invalid_output.spending_key.clone(), + Box::new(DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap()), + ))) + .unwrap(); + backend + .invalidate_unspent_output( + &DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap(), + ) + .unwrap(); let (mut oms, outbound_service, _shutdown, mut base_node_response_sender, _) = - setup_output_manager_service(&mut runtime, OutputManagerMemoryDatabase::new()); + setup_output_manager_service(&mut runtime, backend); + let mut hashes = Vec::new(); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; - let output1 = UnblindedOutput::new(MicroTari::from(value1), key1, None); - + let output1 = UnblindedOutput::new(MicroTari::from(value1), key1.clone(), None); + let tx_output1 = output1.as_transaction_output(&factories).unwrap(); + let output1_hash = tx_output1.hash(); + hashes.push(output1_hash.clone()); runtime.block_on(oms.add_output(output1.clone())).unwrap(); + let key2 = PrivateKey::random(&mut OsRng); let value2 = 800; - let output2 = UnblindedOutput::new(MicroTari::from(value2), key2, None); + let output2 = UnblindedOutput::new(MicroTari::from(value2), key2.clone(), None); + let tx_output2 = output2.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output2.hash()); + runtime.block_on(oms.add_output(output2.clone())).unwrap(); + let key3 = PrivateKey::random(&mut OsRng); + let value3 = 900; + let output3 = UnblindedOutput::new(MicroTari::from(value3), key3.clone(), None); + let tx_output3 = output3.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output3.hash()); + + runtime.block_on(oms.add_output(output3.clone())).unwrap(); + + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 901; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4.clone(), None); + let tx_output4 = output4.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output4.hash()); + + runtime.block_on(oms.add_output(output4.clone())).unwrap(); + let base_node_identity = NodeIdentity::random( &mut OsRng, "/ip4/127.0.0.1/tcp/58217".parse().unwrap(), @@ -672,16 +718,60 @@ fn test_startup_utxo_scan() { .block_on(oms.set_base_node_public_key(base_node_identity.public_key().clone())) .unwrap(); + outbound_service + .wait_call_count(3, Duration::from_secs(60)) + .expect("call wait 1"); + + let (_, _) = outbound_service.pop_call().unwrap(); // Burn the invalid request + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request1: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request1.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + match bn_request2.request { + None => assert!(false, "Invalid request"), + Some(request) => match request { + Request::FetchUtxos(hash_outputs) => { + for h in hash_outputs.outputs { + assert!(hashes.iter().find(|i| **i == h).is_some(), "Should contain hash2"); + } + }, + _ => assert!(false, "invalid request"), + }, + } + let result_stream = runtime.block_on(async { collect_stream!( oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 1, + take = 3, timeout = Duration::from_secs(60) ) }); assert_eq!( - 1, + 3, result_stream.iter().fold(0, |acc, item| { if let OutputManagerEvent::BaseNodeSyncRequestTimedOut(_) = item { acc + 1 @@ -691,18 +781,88 @@ fn test_startup_utxo_scan() { }) ); - let key3 = PrivateKey::random(&mut OsRng); - let value3 = 900; - let output3 = UnblindedOutput::new(MicroTari::from(value3), key3, None); - runtime.block_on(oms.add_output(output3.clone())).unwrap(); + // Test the response to the revalidation call first so as not to confuse the invalidation that happens during the + // responses to the Unspent UTXO queries + let mut invalid_request_key = 0; + let mut unspent_request_key_with_output1 = 0; + + for _ in 0..3 { + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); - let (_, body) = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); - let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body - .decode_part::(1) - .unwrap() + let request_hashes = if let Request::FetchUtxos(outputs) = bn_request.request.unwrap() { + outputs.outputs + } else { + assert!(false, "Wrong request type"); + Vec::new() + }; + + if request_hashes.iter().find(|i| **i == invalid_hash).is_some() { + invalid_request_key = bn_request.request_key; + } + if request_hashes.iter().find(|i| **i == output1_hash).is_some() { + unspent_request_key_with_output1 = bn_request.request_key; + } + } + assert_ne!(invalid_request_key, 0, "Should have found invalid request key"); + assert_ne!( + unspent_request_key_with_output1, 0, + "Should have found request key for request with output 1 in it" + ); + + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_txs.len(), 1); + let base_node_response = BaseNodeProto::BaseNodeServiceResponse { + request_key: invalid_request_key, + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { + outputs: vec![invalid_output.clone().as_transaction_output(&factories).unwrap().into()].into(), + }, + )), + }; + + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response, + base_node_identity.public_key(), + ))) .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! { + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(acc >= 1, "Did not receive enough responses"); + }); + + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_txs.len(), 0); + + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 1000; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4, None); + runtime.block_on(oms.add_output(output4.clone())).unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 0); @@ -726,7 +886,7 @@ fn test_startup_utxo_scan() { assert_eq!(invalid_txs.len(), 0); let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: bn_request.request_key.clone(), + request_key: unspent_request_key_with_output1, response: Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { outputs: vec![output1.clone().as_transaction_output(&factories).unwrap().into()].into(), @@ -741,33 +901,43 @@ fn test_startup_utxo_scan() { ))) .unwrap(); - let result_stream = runtime.block_on(async { - collect_stream!( - oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, - timeout = Duration::from_secs(60) - ) - }); - - assert_eq!( - 1, - result_stream.iter().fold(0, |acc, item| { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = item { - acc + 1 - } else { - acc + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! { + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, } - }) - ); + } + assert!(acc >= 1, "Did not receive enough responses"); + }); let invalid_outputs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_outputs.len(), 1); - assert_eq!(invalid_outputs[0], output2); + let check2 = invalid_outputs[0] == output2; + let check3 = invalid_outputs[0] == output3; + let check4 = invalid_outputs[0] == output4; + + assert!(check2 || check3 || check4, "One of these outputs should be invalid"); let unspent_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); - assert_eq!(unspent_outputs.len(), 2); + assert_eq!(unspent_outputs.len(), 5); assert!(unspent_outputs.iter().find(|uo| uo == &&output1).is_some()); - assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()); + if check2 { + assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()) + } else { + assert!(unspent_outputs.iter().find(|uo| uo == &&output2).is_some()) + } runtime.block_on(oms.sync_with_base_node()).unwrap(); @@ -778,6 +948,20 @@ fn test_startup_utxo_scan() { .unwrap() .unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request2: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request3: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 1); @@ -787,6 +971,7 @@ fn test_startup_utxo_scan() { BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, )), }; + runtime .block_on(base_node_response_sender.send(create_dummy_message( base_node_response, @@ -794,6 +979,32 @@ fn test_startup_utxo_scan() { ))) .unwrap(); + let base_node_response2 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request2.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response2, + base_node_identity.public_key(), + ))) + .unwrap(); + + let base_node_response3 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request3.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response3, + base_node_identity.public_key(), + ))) + .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); runtime.block_on(async { @@ -802,9 +1013,9 @@ fn test_startup_utxo_scan() { loop { futures::select! { event = event_stream.select_next_some() => { - if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + if let OutputManagerEvent::ReceiveBaseNodeResponse(r) = (*event).clone() { acc += 1; - if acc >= 2 { + if acc >= 4 { break; } } @@ -814,11 +1025,11 @@ fn test_startup_utxo_scan() { }, } } - assert!(acc >= 2, "Did not receive enough responses"); + assert!(acc >= 3, "Did not receive enough responses"); }); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); - assert_eq!(invalid_txs.len(), 3); + assert_eq!(invalid_txs.len(), 6); } fn sending_transaction_with_short_term_clear(backend: T) { diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index d0092c38de..b3daf35d50 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -290,6 +290,7 @@ pub fn test_db_backend(backend: T) { .block_on(db.invalidate_output(unspent_outputs[0].clone())) .unwrap(); let invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_outputs.len(), 1); assert_eq!(invalid_outputs[0], unspent_outputs[0]); @@ -297,6 +298,40 @@ pub fn test_db_backend(backend: T) { .block_on(db.invalidate_output(pending_txs[0].outputs_to_be_received[0].clone())) .unwrap(); assert_eq!(tx_id, Some(pending_txs[0].tx_id)); + + // test revalidating output + let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap(); + assert!( + unspent_outputs + .iter() + .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output) + .is_none(), + "Should not find ouput" + ); + + assert!(runtime + .block_on( + db.revalidate_output( + pending_txs[2].outputs_to_be_spent[0] + .unblinded_output + .spending_key + .clone() + ) + ) + .is_err()); + runtime + .block_on(db.revalidate_output(invalid_outputs[0].unblinded_output.spending_key.clone())) + .unwrap(); + let new_invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap(); + assert_eq!(new_invalid_outputs.len(), 1); + let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap(); + assert!( + unspent_outputs + .iter() + .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output) + .is_some(), + "Should find revalidated ouput" + ); } #[test] diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs index 5ef1dc96dc..c6cb7d2c1e 100644 --- a/comms/dht/src/outbound/mock.rs +++ b/comms/dht/src/outbound/mock.rs @@ -36,6 +36,7 @@ use futures::{ stream::Fuse, StreamExt, }; +use log::*; use std::{ sync::{Arc, Condvar, Mutex, RwLock}, time::Duration, @@ -43,6 +44,8 @@ use std::{ use tari_comms::message::MessageTag; use tokio::time::delay_for; +const LOG_TARGET: &str = "mock::outbound_requester"; + /// Creates a mock outbound request "handler" for testing purposes. /// /// Each time a request is expected, handle_next should be called. @@ -163,6 +166,11 @@ impl OutboundServiceMock { while let Some(req) = self.receiver.next().await { match req { DhtOutboundRequest::SendMessage(params, body, reply_tx) => { + trace!( + target: LOG_TARGET, + "Send message request received with length of {} bytes", + body.len() + ); let behaviour = self.mock_state.get_behaviour(); match (*params).clone().broadcast_strategy {