From 6458a3d2c3acb4152369b9648c28702e7aa4f57d Mon Sep 17 00:00:00 2001 From: Philip Robinson Date: Tue, 9 Jun 2020 11:54:42 +0200 Subject: [PATCH] Add revalidation check to Output Manager service Outputs can become marked as invalid by the Output Manager service when they do not appear in the base nodes blockchain. Generally this means that they were re-orged out or have been spent by another copy of the wallet. However they can be reported as invalid incorrectly in some case like, for example, if the Base Node is not fully synced. This PR adds in an addition check that the Output Manager does on startup that will see if any of the invalid outputs has become valid again. If it has that output will changed back into a spendable output. --- .../src/output_manager_service/service.rs | 332 ++++++++++++------ .../storage/database.rs | 10 + .../storage/memory_db.rs | 17 + .../storage/sqlite_db.rs | 19 +- .../tests/output_manager_service/service.rs | 161 +++++++-- .../tests/output_manager_service/storage.rs | 35 ++ 6 files changed, 439 insertions(+), 135 deletions(-) diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index ef8a49dbff..80f0d3c237 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -99,6 +99,7 @@ where TBackend: OutputManagerBackend + 'static factories: CryptoFactories, base_node_public_key: Option, pending_utxo_query_keys: HashMap>>, + pending_revalidation_query_keys: HashMap>>, event_publisher: Publisher, } @@ -155,6 +156,7 @@ where factories, base_node_public_key: None, pending_utxo_query_keys: HashMap::new(), + pending_revalidation_query_keys: HashMap::new(), event_publisher, }) } @@ -290,7 +292,7 @@ where .await .map(|_| OutputManagerResponse::BaseNodePublicKeySet), OutputManagerRequest::SyncWithBaseNode => self - .query_unspent_outputs_status(utxo_query_timeout_futures, None) + .query_unspent_outputs_status(utxo_query_timeout_futures) .await .map(OutputManagerResponse::StartedBaseNodeSync), OutputManagerRequest::GetInvalidOutputs => { @@ -324,86 +326,154 @@ where }, }; - // Only process requests with a request_key that we are expecting. - let queried_hashes: Vec> = match self.pending_utxo_query_keys.remove(&request_key) { - None => { - trace!( - target: LOG_TARGET, - "Ignoring Base Node Response with unexpected request key ({}), it was not meant for this service.", - request_key - ); - return Ok(()); - }, - Some(qh) => qh, - }; + let mut unspent_query_handled = false; + let mut invalid_query_handled = false; - trace!( - target: LOG_TARGET, - "Handling a Base Node Response meant for this service" - ); + // Check if the received key is in the pending UTXO query list to be handled + if let Some(queried_hashes) = self.pending_utxo_query_keys.remove(&request_key) { + trace!( + target: LOG_TARGET, + "Handling a Base Node Response for a Unspent Outputs request ({})", + request_key + ); - // Construct a HashMap of all the unspent outputs - let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; + // Construct a HashMap of all the unspent outputs + let unspent_outputs: Vec = self.db.get_unspent_outputs().await?; - let mut output_hashes = HashMap::new(); - for uo in unspent_outputs.iter() { - let hash = uo.hash.clone(); - if queried_hashes.iter().any(|h| &hash == h) { - output_hashes.insert(hash.clone(), uo.clone()); + let mut output_hashes = HashMap::new(); + for uo in unspent_outputs.iter() { + let hash = uo.hash.clone(); + if queried_hashes.iter().any(|h| &hash == h) { + output_hashes.insert(hash.clone(), uo.clone()); + } } - } - // Go through all the returned UTXOs and if they are in the hashmap remove them - for output in response.iter() { - let response_hash = TransactionOutput::try_from(output.clone()) - .map_err(OutputManagerError::ConversionError)? - .hash(); + // Go through all the returned UTXOs and if they are in the hashmap remove them + for output in response.iter() { + let response_hash = TransactionOutput::try_from(output.clone()) + .map_err(OutputManagerError::ConversionError)? + .hash(); - let _ = output_hashes.remove(&response_hash); - } + let _ = output_hashes.remove(&response_hash); + } - // If there are any remaining Unspent Outputs we will move them to the invalid collection - for (_k, v) in output_hashes { - // Get the transaction these belonged to so we can display the kernel signature of the transaction this - // output belonged to. + // If there are any remaining Unspent Outputs we will move them to the invalid collection + for (_k, v) in output_hashes { + // Get the transaction these belonged to so we can display the kernel signature of the transaction + // this output belonged to. - warn!( - target: LOG_TARGET, - "Output with value {} not returned from Base Node query and is thus being invalidated", - v.unblinded_output.value - ); - // If the output that is being invalidated has an associated TxId then get the kernel signature of the - // transaction and display for easier debugging - if let Some(tx_id) = self.db.invalidate_output(v).await? { - if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await { + warn!( + target: LOG_TARGET, + "Output with value {} not returned from Base Node query and is thus being invalidated", + v.unblinded_output.value + ); + // If the output that is being invalidated has an associated TxId then get the kernel signature of + // the transaction and display for easier debugging + if let Some(tx_id) = self.db.invalidate_output(v).await? { + if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await { + info!( + target: LOG_TARGET, + "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: \ + {}", + transaction.tx_id, + transaction.message, + transaction.transaction.body.kernels()[0] + .excess_sig + .get_signature() + .to_hex() + ) + } + } else { info!( target: LOG_TARGET, - "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: {}", - transaction.tx_id, - transaction.message, - transaction.transaction.body.kernels()[0] - .excess_sig - .get_signature() - .to_hex() - ) + "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost \ + to a Re-Org" + ); + } + } + unspent_query_handled = true; + debug!( + target: LOG_TARGET, + "Handled Base Node response for Unspent Outputs Query {}", request_key + ); + }; + + // Check if the received key is in the Invalid UTXO query list waiting to be handled + if self.pending_revalidation_query_keys.remove(&request_key).is_some() { + trace!( + target: LOG_TARGET, + "Handling a Base Node Response for a Invalid Outputs request ({})", + request_key + ); + let invalid_outputs = self.db.get_invalid_outputs().await?; + + for output in response.iter() { + let response_hash = TransactionOutput::try_from(output.clone()) + .map_err(OutputManagerError::ConversionError)? + .hash(); + + if let Some(output) = invalid_outputs.iter().find(|o| o.hash == response_hash) { + if self + .db + .revalidate_output(output.unblinded_output.spending_key.clone()) + .await + .is_ok() + { + trace!( + target: LOG_TARGET, + "Output with value {} has been restored to a valid spendable output", + output.unblinded_output.value + ); + } } - } else { - info!( - target: LOG_TARGET, - "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost to a \ - Re-Org" - ); } + invalid_query_handled = true; + debug!( + target: LOG_TARGET, + "Handled Base Node response for Invalid Outputs Query {}", request_key + ); } - debug!( - target: LOG_TARGET, - "Handled Base Node response for Query {}", request_key - ); + if unspent_query_handled || invalid_query_handled { + let _ = self + .event_publisher + .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) + .await + .map_err(|e| { + trace!( + target: LOG_TARGET, + "Error sending event, usually because there are no subscribers: {:?}", + e + ); + e + }); + } + + Ok(()) + } + + /// Handle the timeout of a pending UTXO query. + pub async fn handle_utxo_query_timeout( + &mut self, + query_key: u64, + utxo_query_timeout_futures: &mut FuturesUnordered>, + ) -> Result<(), OutputManagerError> + { + if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Unspent Outputs Query {} timed out", query_key); + self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::UnspentOutputs) + .await?; + } + + if let Some(hashes) = self.pending_revalidation_query_keys.remove(&query_key) { + warn!(target: LOG_TARGET, "UTXO Invalid Outputs Query {} timed out", query_key); + self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::InvalidOutputs) + .await?; + } let _ = self .event_publisher - .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key)) + .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) .await .map_err(|e| { trace!( @@ -413,70 +483,82 @@ where ); e }); - Ok(()) } - /// Handle the timeout of a pending UTXO query. - pub async fn handle_utxo_query_timeout( + pub async fn query_unspent_outputs_status( &mut self, - query_key: u64, utxo_query_timeout_futures: &mut FuturesUnordered>, - ) -> Result<(), OutputManagerError> + ) -> Result { - if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) { - warn!(target: LOG_TARGET, "UTXO Query {} timed out", query_key); - self.query_unspent_outputs_status(utxo_query_timeout_futures, Some(hashes)) - .await?; + let unspent_output_hashes = self + .db + .get_unspent_outputs() + .await? + .iter() + .map(|uo| uo.hash.clone()) + .collect(); + + let key = self + .query_outputs_status( + utxo_query_timeout_futures, + unspent_output_hashes, + UtxoQueryType::UnspentOutputs, + ) + .await?; - let _ = self - .event_publisher - .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key)) - .await - .map_err(|e| { - trace!( - target: LOG_TARGET, - "Error sending event, usually because there are no subscribers: {:?}", - e - ); - e - }); - } - Ok(()) + Ok(key) + } + + pub async fn query_invalid_outputs_status( + &mut self, + utxo_query_timeout_futures: &mut FuturesUnordered>, + ) -> Result + { + let invalid_output_hashes: Vec<_> = self + .db + .get_invalid_outputs() + .await? + .into_iter() + .map(|uo| uo.hash) + .collect(); + + let key = if !invalid_output_hashes.is_empty() { + self.query_outputs_status( + utxo_query_timeout_futures, + invalid_output_hashes, + UtxoQueryType::InvalidOutputs, + ) + .await? + } else { + 0 + }; + + Ok(key) } - /// Send queries to the base node to check the status of all unspent outputs. If the outputs are no longer - /// available their status will be updated in the wallet. - async fn query_unspent_outputs_status( + /// Send queries to the base node to check the status of all specified outputs. + async fn query_outputs_status( &mut self, utxo_query_timeout_futures: &mut FuturesUnordered>, - specified_outputs: Option>>, + mut outputs_to_query: Vec>, + query_type: UtxoQueryType, ) -> Result { match self.base_node_public_key.as_ref() { None => Err(OutputManagerError::NoBaseNodeKeysProvided), Some(pk) => { let mut first_request_key = 0; - let mut unspent_outputs: Vec> = if let Some(hashes) = specified_outputs { - hashes - } else { - self.db - .get_unspent_outputs() - .await? - .iter() - .map(|uo| uo.hash.clone()) - .collect() - }; // Determine how many rounds of base node request we need to query all the outputs in batches of // max_utxo_query_size let rounds = - ((unspent_outputs.len() as f32) / (self.config.max_utxo_query_size as f32)).ceil() as usize; + ((outputs_to_query.len() as f32) / (self.config.max_utxo_query_size as f32 + 0.1)) as usize + 1; for r in 0..rounds { let mut output_hashes = Vec::new(); for uo_hash in - unspent_outputs.drain(..cmp::min(self.config.max_utxo_query_size, unspent_outputs.len())) + outputs_to_query.drain(..cmp::min(self.config.max_utxo_query_size, outputs_to_query.len())) { output_hashes.push(uo_hash); } @@ -509,14 +591,14 @@ where match send_message_response.resolve_ok().await { None => trace!( target: LOG_TARGET, - "Failed to send Output Manager UTXO Sync query ({}) to Base Node", + "Failed to send Output Manager UTXO query ({}) to Base Node", request_key ), Some(send_states) => { if send_states.len() == 1 { trace!( target: LOG_TARGET, - "Output Manager UTXO Sync query ({}) queued for sending with Message {}", + "Output Manager UTXO query ({}) queued for sending with Message {}", request_key, send_states[0].tag, ); @@ -524,7 +606,7 @@ where if send_states.wait_single().await { trace!( target: LOG_TARGET, - "Output Manager UTXO Sync query ({}) successfully sent to Base Node with \ + "Output Manager UTXO query ({}) successfully sent to Base Node with \ Message {}", request_key, message_tag, @@ -532,8 +614,8 @@ where } else { trace!( target: LOG_TARGET, - "Failed to send Output Manager UTXO Sync query ({}) to Base Node with \ - Message {}", + "Failed to send Output Manager UTXO query ({}) to Base Node with Message \ + {}", request_key, message_tag, ); @@ -541,7 +623,7 @@ where } else { trace!( target: LOG_TARGET, - "Failed to send Output Manager UTXO Sync query ({}) to Base Node", + "Failed to send Output Manager UTXO query ({}) to Base Node", request_key ) } @@ -549,12 +631,22 @@ where } }); - self.pending_utxo_query_keys.insert(request_key, output_hashes); + match query_type { + UtxoQueryType::UnspentOutputs => { + self.pending_utxo_query_keys.insert(request_key, output_hashes); + }, + UtxoQueryType::InvalidOutputs => { + self.pending_revalidation_query_keys.insert(request_key, output_hashes); + }, + } + let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key); utxo_query_timeout_futures.push(state_timeout.delay().boxed()); + debug!( target: LOG_TARGET, - "Output Manager Sync query ({}) sent to Base Node, part {} of {} requests", + "Output Manager {} query ({}) sent to Base Node, part {} of {} requests", + query_type, request_key, r + 1, rounds @@ -866,8 +958,8 @@ where self.base_node_public_key = Some(base_node_public_key); if startup_query { - self.query_unspent_outputs_status(utxo_query_timeout_futures, None) - .await?; + self.query_unspent_outputs_status(utxo_query_timeout_futures).await?; + self.query_invalid_outputs_status(utxo_query_timeout_futures).await?; } Ok(()) } @@ -1026,3 +1118,17 @@ impl fmt::Display for Balance { Ok(()) } } + +enum UtxoQueryType { + UnspentOutputs, + InvalidOutputs, +} + +impl fmt::Display for UtxoQueryType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + UtxoQueryType::UnspentOutputs => write!(f, "Unspent Outputs"), + UtxoQueryType::InvalidOutputs => write!(f, "Invalid Outputs"), + } + } +} diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs index 0c3a35be1e..eaea0f0b45 100644 --- a/base_layer/wallet/src/output_manager_service/storage/database.rs +++ b/base_layer/wallet/src/output_manager_service/storage/database.rs @@ -83,6 +83,8 @@ pub trait OutputManagerBackend: Send + Sync { /// If an unspent output is detected as invalid (i.e. not available on the blockchain) then it should be moved to /// the invalid outputs collection. The function will return the last recorded TxId associated with this output. fn invalidate_unspent_output(&self, output: &DbUnblindedOutput) -> Result, OutputManagerStorageError>; + /// If an invalid output is found to be valid this function will turn it back into an unspent output + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError>; } /// Holds the outputs that have been selected for a given pending transaction waiting for confirmation @@ -502,6 +504,14 @@ where T: OutputManagerBackend + 'static .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) .and_then(|inner_result| inner_result) } + + pub async fn revalidate_output(&self, spending_key: BlindingFactor) -> Result<(), OutputManagerStorageError> { + let db_clone = self.db.clone(); + tokio::task::spawn_blocking(move || db_clone.revalidate_unspent_output(&spending_key)) + .await + .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string()))) + .and_then(|inner_result| inner_result) + } } fn unexpected_result(req: DbKey, res: DbValue) -> Result { diff --git a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs index 3b6c5fe2c0..969f7803bf 100644 --- a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs @@ -42,6 +42,7 @@ use std::{ sync::{Arc, RwLock}, time::Duration, }; +use tari_core::transactions::types::BlindingFactor; /// This structure is an In-Memory database backend that implements the `OutputManagerBackend` trait and provides all /// the functionality required by the trait. @@ -372,6 +373,22 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase { None => Err(OutputManagerStorageError::ValuesNotFound), } } + + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> { + let mut db = acquire_write_lock!(self.db); + match db + .invalid_outputs + .iter() + .position(|v| v.output.unblinded_output.spending_key == *spending_key) + { + Some(pos) => { + let output = db.invalid_outputs.remove(pos); + db.unspent_outputs.push(output); + Ok(()) + }, + None => Err(OutputManagerStorageError::ValuesNotFound), + } + } } // A struct that contains the extra info we are using in the Sql version of this backend diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs index c2ea85e7a6..363773a557 100644 --- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs +++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs @@ -54,7 +54,7 @@ use tari_core::{ transactions::{ tari_amount::MicroTari, transaction::{OutputFeatures, OutputFlags, UnblindedOutput}, - types::{CryptoFactories, PrivateKey}, + types::{BlindingFactor, CryptoFactories, PrivateKey}, }, }; use tari_crypto::tari_utilities::ByteArray; @@ -429,6 +429,23 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase { Ok(tx_id) } + + fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> { + let conn = acquire_lock!(self.database_connection); + let output = OutputSql::find(&spending_key.to_vec(), &conn)?; + + if OutputStatus::try_from(output.status)? != OutputStatus::Invalid { + return Err(OutputManagerStorageError::ValuesNotFound); + } + let _ = output.update( + UpdateOutput { + status: Some(OutputStatus::Unspent), + tx_id: None, + }, + &(*conn), + )?; + Ok(()) + } } /// A utility function to construct a PendingTransactionOutputs structure for a TxId, set of Outputs and a Timestamp diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs index 54436c0736..00f84bbb95 100644 --- a/base_layer/wallet/tests/output_manager_service/service.rs +++ b/base_layer/wallet/tests/output_manager_service/service.rs @@ -73,8 +73,9 @@ use tari_wallet::{ handle::{OutputManagerEvent, OutputManagerHandle}, service::OutputManagerService, storage::{ - database::{DbKey, DbValue, OutputManagerBackend, OutputManagerDatabase}, + database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation}, memory_db::OutputManagerMemoryDatabase, + models::DbUnblindedOutput, sqlite_db::OutputManagerSqliteDatabase, }, }, @@ -652,15 +653,34 @@ fn test_startup_utxo_scan() { let factories = CryptoFactories::default(); let mut runtime = Runtime::new().unwrap(); + let backend = OutputManagerMemoryDatabase::new(); + + let invalid_key = PrivateKey::random(&mut OsRng); + let invalid_value = 666; + let invalid_output = UnblindedOutput::new(MicroTari::from(invalid_value), invalid_key.clone(), None); + let invalid_hash = invalid_output.as_transaction_output(&factories).unwrap().hash(); + + backend + .write(WriteOperation::Insert(DbKeyValuePair::UnspentOutput( + invalid_output.spending_key.clone(), + Box::new(DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap()), + ))) + .unwrap(); + backend + .invalidate_unspent_output( + &DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap(), + ) + .unwrap(); let (mut oms, outbound_service, _shutdown, mut base_node_response_sender, _) = - setup_output_manager_service(&mut runtime, OutputManagerMemoryDatabase::new()); + setup_output_manager_service(&mut runtime, backend); let mut hashes = Vec::new(); let key1 = PrivateKey::random(&mut OsRng); let value1 = 500; let output1 = UnblindedOutput::new(MicroTari::from(value1), key1.clone(), None); let tx_output1 = output1.as_transaction_output(&factories).unwrap(); - hashes.push(tx_output1.hash()); + let output1_hash = tx_output1.hash(); + hashes.push(output1_hash.clone()); runtime.block_on(oms.add_output(output1.clone())).unwrap(); let key2 = PrivateKey::random(&mut OsRng); @@ -679,6 +699,14 @@ fn test_startup_utxo_scan() { runtime.block_on(oms.add_output(output3.clone())).unwrap(); + let key4 = PrivateKey::random(&mut OsRng); + let value4 = 901; + let output4 = UnblindedOutput::new(MicroTari::from(value4), key4.clone(), None); + let tx_output4 = output4.as_transaction_output(&factories).unwrap(); + hashes.push(tx_output4.hash()); + + runtime.block_on(oms.add_output(output4.clone())).unwrap(); + let base_node_identity = NodeIdentity::random( &mut OsRng, "/ip4/127.0.0.1/tcp/58217".parse().unwrap(), @@ -691,8 +719,11 @@ fn test_startup_utxo_scan() { .unwrap(); outbound_service - .wait_call_count(2, Duration::from_secs(60)) + .wait_call_count(3, Duration::from_secs(60)) .expect("call wait 1"); + + let (_, _) = outbound_service.pop_call().unwrap(); // Burn the invalid request + let (_, body) = outbound_service.pop_call().unwrap(); let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); let bn_request1: BaseNodeProto::BaseNodeServiceRequest = envelope_body @@ -734,13 +765,13 @@ fn test_startup_utxo_scan() { let result_stream = runtime.block_on(async { collect_stream!( oms.get_event_stream_fused().map(|i| (*i).clone()), - take = 2, + take = 3, timeout = Duration::from_secs(60) ) }); assert_eq!( - 2, + 3, result_stream.iter().fold(0, |acc, item| { if let OutputManagerEvent::BaseNodeSyncRequestTimedOut(_) = item { acc + 1 @@ -750,19 +781,88 @@ fn test_startup_utxo_scan() { }) ); + // Test the response to the revalidation call first so as not to confuse the invalidation that happens during the + // responses to the Unspent UTXO queries + let mut invalid_request_key = 0; + let mut unspent_request_key_with_output1 = 0; + + for _ in 0..3 { + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + + let request_hashes = if let Request::FetchUtxos(outputs) = bn_request.request.unwrap() { + outputs.outputs + } else { + assert!(false, "Wrong request type"); + Vec::new() + }; + + if request_hashes.iter().find(|i| **i == invalid_hash).is_some() { + invalid_request_key = bn_request.request_key; + } + if request_hashes.iter().find(|i| **i == output1_hash).is_some() { + unspent_request_key_with_output1 = bn_request.request_key; + } + } + assert_ne!(invalid_request_key, 0, "Should have found invalid request key"); + assert_ne!( + unspent_request_key_with_output1, 0, + "Should have found request key for request with output 1 in it" + ); + + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_txs.len(), 1); + let base_node_response = BaseNodeProto::BaseNodeServiceResponse { + request_key: invalid_request_key, + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { + outputs: vec![invalid_output.clone().as_transaction_output(&factories).unwrap().into()].into(), + }, + )), + }; + + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response, + base_node_identity.public_key(), + ))) + .unwrap(); + + let mut event_stream = oms.get_event_stream_fused(); + + runtime.block_on(async { + let mut delay = delay_for(Duration::from_secs(60)).fuse(); + let mut acc = 0; + loop { + futures::select! { + event = event_stream.select_next_some() => { + if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() { + acc += 1; + if acc >= 1 { + break; + } + } + }, + () = delay => { + break; + }, + } + } + assert!(acc >= 1, "Did not receive enough responses"); + }); + + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_txs.len(), 0); + let key4 = PrivateKey::random(&mut OsRng); let value4 = 1000; let output4 = UnblindedOutput::new(MicroTari::from(value4), key4, None); runtime.block_on(oms.add_output(output4.clone())).unwrap(); - let (_, body) = outbound_service.pop_call().unwrap(); - let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); - let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body - .decode_part::(1) - .unwrap() - .unwrap(); - let _ = outbound_service.pop_call().unwrap(); - let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 0); @@ -786,7 +886,7 @@ fn test_startup_utxo_scan() { assert_eq!(invalid_txs.len(), 0); let base_node_response = BaseNodeProto::BaseNodeServiceResponse { - request_key: bn_request.request_key.clone(), + request_key: unspent_request_key_with_output1, response: Some(BaseNodeResponseProto::TransactionOutputs( BaseNodeProto::TransactionOutputs { outputs: vec![output1.clone().as_transaction_output(&factories).unwrap().into()].into(), @@ -801,8 +901,6 @@ fn test_startup_utxo_scan() { ))) .unwrap(); - let mut event_stream = oms.get_event_stream_fused(); - runtime.block_on(async { let mut delay = delay_for(Duration::from_secs(60)).fuse(); let mut acc = 0; @@ -828,11 +926,12 @@ fn test_startup_utxo_scan() { assert_eq!(invalid_outputs.len(), 1); let check2 = invalid_outputs[0] == output2; let check3 = invalid_outputs[0] == output3; + let check4 = invalid_outputs[0] == output4; - assert!(check2 || check3, "One of these outputs should be invalid"); + assert!(check2 || check3 || check4, "One of these outputs should be invalid"); let unspent_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap(); - assert_eq!(unspent_outputs.len(), 3); + assert_eq!(unspent_outputs.len(), 5); assert!(unspent_outputs.iter().find(|uo| uo == &&output1).is_some()); if check2 { assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some()) @@ -856,6 +955,13 @@ fn test_startup_utxo_scan() { .unwrap() .unwrap(); + let (_, body) = outbound_service.pop_call().unwrap(); + let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap(); + let bn_request3: BaseNodeProto::BaseNodeServiceRequest = envelope_body + .decode_part::(1) + .unwrap() + .unwrap(); + let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); assert_eq!(invalid_txs.len(), 1); @@ -886,6 +992,19 @@ fn test_startup_utxo_scan() { ))) .unwrap(); + let base_node_response3 = BaseNodeProto::BaseNodeServiceResponse { + request_key: bn_request3.request_key.clone(), + response: Some(BaseNodeResponseProto::TransactionOutputs( + BaseNodeProto::TransactionOutputs { outputs: vec![].into() }, + )), + }; + runtime + .block_on(base_node_response_sender.send(create_dummy_message( + base_node_response3, + base_node_identity.public_key(), + ))) + .unwrap(); + let mut event_stream = oms.get_event_stream_fused(); runtime.block_on(async { @@ -896,7 +1015,7 @@ fn test_startup_utxo_scan() { event = event_stream.select_next_some() => { if let OutputManagerEvent::ReceiveBaseNodeResponse(r) = (*event).clone() { acc += 1; - if acc >= 3 { + if acc >= 4 { break; } } @@ -910,7 +1029,7 @@ fn test_startup_utxo_scan() { }); let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap(); - assert_eq!(invalid_txs.len(), 4); + assert_eq!(invalid_txs.len(), 6); } fn sending_transaction_with_short_term_clear(backend: T) { diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs index d0092c38de..b3daf35d50 100644 --- a/base_layer/wallet/tests/output_manager_service/storage.rs +++ b/base_layer/wallet/tests/output_manager_service/storage.rs @@ -290,6 +290,7 @@ pub fn test_db_backend(backend: T) { .block_on(db.invalidate_output(unspent_outputs[0].clone())) .unwrap(); let invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap(); + assert_eq!(invalid_outputs.len(), 1); assert_eq!(invalid_outputs[0], unspent_outputs[0]); @@ -297,6 +298,40 @@ pub fn test_db_backend(backend: T) { .block_on(db.invalidate_output(pending_txs[0].outputs_to_be_received[0].clone())) .unwrap(); assert_eq!(tx_id, Some(pending_txs[0].tx_id)); + + // test revalidating output + let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap(); + assert!( + unspent_outputs + .iter() + .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output) + .is_none(), + "Should not find ouput" + ); + + assert!(runtime + .block_on( + db.revalidate_output( + pending_txs[2].outputs_to_be_spent[0] + .unblinded_output + .spending_key + .clone() + ) + ) + .is_err()); + runtime + .block_on(db.revalidate_output(invalid_outputs[0].unblinded_output.spending_key.clone())) + .unwrap(); + let new_invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap(); + assert_eq!(new_invalid_outputs.len(), 1); + let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap(); + assert!( + unspent_outputs + .iter() + .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output) + .is_some(), + "Should find revalidated ouput" + ); } #[test]