diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs
index ef8a49dbff..80f0d3c237 100644
--- a/base_layer/wallet/src/output_manager_service/service.rs
+++ b/base_layer/wallet/src/output_manager_service/service.rs
@@ -99,6 +99,7 @@ where TBackend: OutputManagerBackend + 'static
     factories: CryptoFactories,
     base_node_public_key: Option<CommsPublicKey>,
     pending_utxo_query_keys: HashMap<u64, Vec<Vec<u8>>>,
+    pending_revalidation_query_keys: HashMap<u64, Vec<Vec<u8>>>,
     event_publisher: Publisher<OutputManagerEvent>,
 }
@@ -155,6 +156,7 @@ where
             factories,
             base_node_public_key: None,
             pending_utxo_query_keys: HashMap::new(),
+            pending_revalidation_query_keys: HashMap::new(),
             event_publisher,
         })
     }
@@ -290,7 +292,7 @@ where
                 .await
                 .map(|_| OutputManagerResponse::BaseNodePublicKeySet),
             OutputManagerRequest::SyncWithBaseNode => self
-                .query_unspent_outputs_status(utxo_query_timeout_futures, None)
+                .query_unspent_outputs_status(utxo_query_timeout_futures)
                 .await
                 .map(OutputManagerResponse::StartedBaseNodeSync),
             OutputManagerRequest::GetInvalidOutputs => {
@@ -324,86 +326,154 @@ where
             },
         };

-        // Only process requests with a request_key that we are expecting.
-        let queried_hashes: Vec<Vec<u8>> = match self.pending_utxo_query_keys.remove(&request_key) {
-            None => {
-                trace!(
-                    target: LOG_TARGET,
-                    "Ignoring Base Node Response with unexpected request key ({}), it was not meant for this service.",
-                    request_key
-                );
-                return Ok(());
-            },
-            Some(qh) => qh,
-        };
+        let mut unspent_query_handled = false;
+        let mut invalid_query_handled = false;

-        trace!(
-            target: LOG_TARGET,
-            "Handling a Base Node Response meant for this service"
-        );
+        // Check if the received key is in the pending UTXO query list to be handled
+        if let Some(queried_hashes) = self.pending_utxo_query_keys.remove(&request_key) {
+            trace!(
+                target: LOG_TARGET,
+                "Handling a Base Node Response for an Unspent Outputs request ({})",
+                request_key
+            );

-        // Construct a HashMap of all the unspent outputs
-        let unspent_outputs: Vec<DbUnblindedOutput> = self.db.get_unspent_outputs().await?;
+            // Construct a HashMap of all the unspent outputs
+            let unspent_outputs: Vec<DbUnblindedOutput> = self.db.get_unspent_outputs().await?;

-        let mut output_hashes = HashMap::new();
-        for uo in unspent_outputs.iter() {
-            let hash = uo.hash.clone();
-            if queried_hashes.iter().any(|h| &hash == h) {
-                output_hashes.insert(hash.clone(), uo.clone());
+            let mut output_hashes = HashMap::new();
+            for uo in unspent_outputs.iter() {
+                let hash = uo.hash.clone();
+                if queried_hashes.iter().any(|h| &hash == h) {
+                    output_hashes.insert(hash.clone(), uo.clone());
+                }
             }
-        }

-        // Go through all the returned UTXOs and if they are in the hashmap remove them
-        for output in response.iter() {
-            let response_hash = TransactionOutput::try_from(output.clone())
-                .map_err(OutputManagerError::ConversionError)?
-                .hash();
+            // Go through all the returned UTXOs and if they are in the hashmap remove them
+            for output in response.iter() {
+                let response_hash = TransactionOutput::try_from(output.clone())
+                    .map_err(OutputManagerError::ConversionError)?
+                    .hash();

-            let _ = output_hashes.remove(&response_hash);
-        }
+                let _ = output_hashes.remove(&response_hash);
+            }

-        // If there are any remaining Unspent Outputs we will move them to the invalid collection
-        for (_k, v) in output_hashes {
-            // Get the transaction these belonged to so we can display the kernel signature of the transaction this
-            // output belonged to.
+            // If there are any remaining Unspent Outputs we will move them to the invalid collection
+            for (_k, v) in output_hashes {
+                // Get the transaction these belonged to so we can display the kernel signature of the transaction
+                // this output belonged to.
-            warn!(
-                target: LOG_TARGET,
-                "Output with value {} not returned from Base Node query and is thus being invalidated",
-                v.unblinded_output.value
-            );
-            // If the output that is being invalidated has an associated TxId then get the kernel signature of the
-            // transaction and display for easier debugging
-            if let Some(tx_id) = self.db.invalidate_output(v).await? {
-                if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await {
+                warn!(
+                    target: LOG_TARGET,
+                    "Output with value {} not returned from Base Node query and is thus being invalidated",
+                    v.unblinded_output.value
+                );
+                // If the output that is being invalidated has an associated TxId then get the kernel signature of
+                // the transaction and display for easier debugging
+                if let Some(tx_id) = self.db.invalidate_output(v).await? {
+                    if let Ok(transaction) = self.transaction_service.get_completed_transaction(tx_id).await {
+                        info!(
+                            target: LOG_TARGET,
+                            "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: \
+                             {}",
+                            transaction.tx_id,
+                            transaction.message,
+                            transaction.transaction.body.kernels()[0]
+                                .excess_sig
+                                .get_signature()
+                                .to_hex()
+                        )
+                    }
+                } else {
                     info!(
                         target: LOG_TARGET,
-                        "Invalidated Output is from Transaction (TxId: {}) with message: {} and Kernel Signature: {}",
-                        transaction.tx_id,
-                        transaction.message,
-                        transaction.transaction.body.kernels()[0]
-                            .excess_sig
-                            .get_signature()
-                            .to_hex()
-                    )
+                        "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost \
+                         to a Re-Org"
+                    );
+                }
+            }
+            unspent_query_handled = true;
+            debug!(
+                target: LOG_TARGET,
+                "Handled Base Node response for Unspent Outputs Query {}", request_key
+            );
+        };
+
+        // Check if the received key is in the Invalid UTXO query list waiting to be handled
+        if self.pending_revalidation_query_keys.remove(&request_key).is_some() {
+            trace!(
+                target: LOG_TARGET,
+                "Handling a Base Node Response for an Invalid Outputs request ({})",
+                request_key
+            );
+            let invalid_outputs = self.db.get_invalid_outputs().await?;
+
+            for output in response.iter() {
+                let response_hash = TransactionOutput::try_from(output.clone())
+                    .map_err(OutputManagerError::ConversionError)?
+                    .hash();
+
+                if let Some(output) = invalid_outputs.iter().find(|o| o.hash == response_hash) {
+                    if self
+                        .db
+                        .revalidate_output(output.unblinded_output.spending_key.clone())
+                        .await
+                        .is_ok()
+                    {
+                        trace!(
+                            target: LOG_TARGET,
+                            "Output with value {} has been restored to a valid spendable output",
+                            output.unblinded_output.value
+                        );
+                    }
+                }
             }
-            } else {
-                info!(
-                    target: LOG_TARGET,
-                    "Invalidated Output does not have an associated TxId so it is likely a Coinbase output lost to a \
-                     Re-Org"
-                );
             }
+            invalid_query_handled = true;
+            debug!(
+                target: LOG_TARGET,
+                "Handled Base Node response for Invalid Outputs Query {}", request_key
+            );
         }

-        debug!(
-            target: LOG_TARGET,
-            "Handled Base Node response for Query {}", request_key
-        );
+        if unspent_query_handled || invalid_query_handled {
+            let _ = self
+                .event_publisher
+                .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key))
+                .await
+                .map_err(|e| {
+                    trace!(
+                        target: LOG_TARGET,
+                        "Error sending event, usually because there are no subscribers: {:?}",
+                        e
+                    );
+                    e
+                });
+        }
+
+        Ok(())
+    }
+
+    /// Handle the timeout of a pending UTXO query.
+    pub async fn handle_utxo_query_timeout(
+        &mut self,
+        query_key: u64,
+        utxo_query_timeout_futures: &mut FuturesUnordered<BoxFuture<'static, u64>>,
+    ) -> Result<(), OutputManagerError>
+    {
+        if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) {
+            warn!(target: LOG_TARGET, "UTXO Unspent Outputs Query {} timed out", query_key);
+            self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::UnspentOutputs)
+                .await?;
+        }
+
+        if let Some(hashes) = self.pending_revalidation_query_keys.remove(&query_key) {
+            warn!(target: LOG_TARGET, "UTXO Invalid Outputs Query {} timed out", query_key);
+            self.query_outputs_status(utxo_query_timeout_futures, hashes, UtxoQueryType::InvalidOutputs)
+                .await?;
+        }

         let _ = self
             .event_publisher
-            .send(OutputManagerEvent::ReceiveBaseNodeResponse(request_key))
+            .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key))
             .await
             .map_err(|e| {
                 trace!(
@@ -413,70 +483,82 @@ where
                 );
                 e
             });
-
         Ok(())
     }

-    /// Handle the timeout of a pending UTXO query.
-    pub async fn handle_utxo_query_timeout(
+    pub async fn query_unspent_outputs_status(
         &mut self,
-        query_key: u64,
         utxo_query_timeout_futures: &mut FuturesUnordered<BoxFuture<'static, u64>>,
-    ) -> Result<(), OutputManagerError>
+    ) -> Result<u64, OutputManagerError>
     {
-        if let Some(hashes) = self.pending_utxo_query_keys.remove(&query_key) {
-            warn!(target: LOG_TARGET, "UTXO Query {} timed out", query_key);
-            self.query_unspent_outputs_status(utxo_query_timeout_futures, Some(hashes))
-                .await?;
+        let unspent_output_hashes = self
+            .db
+            .get_unspent_outputs()
+            .await?
+            .iter()
+            .map(|uo| uo.hash.clone())
+            .collect();
+
+        let key = self
+            .query_outputs_status(
+                utxo_query_timeout_futures,
+                unspent_output_hashes,
+                UtxoQueryType::UnspentOutputs,
+            )
+            .await?;

-            let _ = self
-                .event_publisher
-                .send(OutputManagerEvent::BaseNodeSyncRequestTimedOut(query_key))
-                .await
-                .map_err(|e| {
-                    trace!(
-                        target: LOG_TARGET,
-                        "Error sending event, usually because there are no subscribers: {:?}",
-                        e
-                    );
-                    e
-                });
-        }
-        Ok(())
+        Ok(key)
+    }
+
+    pub async fn query_invalid_outputs_status(
+        &mut self,
+        utxo_query_timeout_futures: &mut FuturesUnordered<BoxFuture<'static, u64>>,
+    ) -> Result<u64, OutputManagerError>
+    {
+        let invalid_output_hashes: Vec<_> = self
+            .db
+            .get_invalid_outputs()
+            .await?
+            .into_iter()
+            .map(|uo| uo.hash)
+            .collect();
+
+        let key = if !invalid_output_hashes.is_empty() {
+            self.query_outputs_status(
+                utxo_query_timeout_futures,
+                invalid_output_hashes,
+                UtxoQueryType::InvalidOutputs,
+            )
+            .await?
+        } else {
+            0
+        };
+
+        Ok(key)
     }

-    /// Send queries to the base node to check the status of all unspent outputs. If the outputs are no longer
-    /// available their status will be updated in the wallet.
-    async fn query_unspent_outputs_status(
+    /// Send queries to the base node to check the status of all specified outputs.
+    async fn query_outputs_status(
         &mut self,
         utxo_query_timeout_futures: &mut FuturesUnordered<BoxFuture<'static, u64>>,
-        specified_outputs: Option<Vec<Vec<u8>>>,
+        mut outputs_to_query: Vec<Vec<u8>>,
+        query_type: UtxoQueryType,
     ) -> Result<u64, OutputManagerError>
     {
         match self.base_node_public_key.as_ref() {
             None => Err(OutputManagerError::NoBaseNodeKeysProvided),
             Some(pk) => {
                 let mut first_request_key = 0;
-                let mut unspent_outputs: Vec<Vec<u8>> = if let Some(hashes) = specified_outputs {
-                    hashes
-                } else {
-                    self.db
-                        .get_unspent_outputs()
-                        .await?
-                        .iter()
-                        .map(|uo| uo.hash.clone())
-                        .collect()
-                };

                 // Determine how many rounds of base node request we need to query all the outputs in batches of
                 // max_utxo_query_size
                 let rounds =
-                    ((unspent_outputs.len() as f32) / (self.config.max_utxo_query_size as f32)).ceil() as usize;
+                    ((outputs_to_query.len() as f32) / (self.config.max_utxo_query_size as f32 + 0.1)) as usize + 1;

                 for r in 0..rounds {
                     let mut output_hashes = Vec::new();
                     for uo_hash in
-                        unspent_outputs.drain(..cmp::min(self.config.max_utxo_query_size, unspent_outputs.len()))
+                        outputs_to_query.drain(..cmp::min(self.config.max_utxo_query_size, outputs_to_query.len()))
                     {
                         output_hashes.push(uo_hash);
                     }
@@ -509,14 +591,14 @@ where
                         match send_message_response.resolve_ok().await {
                             None => trace!(
                                 target: LOG_TARGET,
-                                "Failed to send Output Manager UTXO Sync query ({}) to Base Node",
+                                "Failed to send Output Manager UTXO query ({}) to Base Node",
                                 request_key
                             ),
                             Some(send_states) => {
                                 if send_states.len() == 1 {
                                     trace!(
                                         target: LOG_TARGET,
-                                        "Output Manager UTXO Sync query ({}) queued for sending with Message {}",
+                                        "Output Manager UTXO query ({}) queued for sending with Message {}",
                                         request_key,
                                         send_states[0].tag,
                                     );
@@ -524,7 +606,7 @@ where
                                     if send_states.wait_single().await {
                                         trace!(
                                             target: LOG_TARGET,
-                                            "Output Manager UTXO Sync query ({}) successfully sent to Base Node with \
+                                            "Output Manager UTXO query ({}) successfully sent to Base Node with \
                                              Message {}",
                                             request_key,
                                             message_tag,
@@ -532,8 +614,8 @@ where
                                     } else {
                                         trace!(
                                             target: LOG_TARGET,
-                                            "Failed to send Output Manager UTXO Sync query ({}) to Base Node with \
-                                             Message {}",
+                                            "Failed to send Output Manager UTXO query ({}) to Base Node with Message \
+                                             {}",
                                             request_key,
                                             message_tag,
                                         );
@@ -541,7 +623,7 @@ where
                                 } else {
                                     trace!(
                                         target: LOG_TARGET,
-                                        "Failed to send Output Manager UTXO Sync query ({}) to Base Node",
+                                        "Failed to send Output Manager UTXO query ({}) to Base Node",
                                         request_key
                                     )
                                 }
@@ -549,12 +631,22 @@ where
                         }
                     });

-                    self.pending_utxo_query_keys.insert(request_key, output_hashes);
+                    match query_type {
+                        UtxoQueryType::UnspentOutputs => {
+                            self.pending_utxo_query_keys.insert(request_key, output_hashes);
+                        },
+                        UtxoQueryType::InvalidOutputs => {
+                            self.pending_revalidation_query_keys.insert(request_key, output_hashes);
+                        },
+                    }
+
                     let state_timeout = StateDelay::new(self.config.base_node_query_timeout, request_key);
                     utxo_query_timeout_futures.push(state_timeout.delay().boxed());
+
                     debug!(
                         target: LOG_TARGET,
-                        "Output Manager Sync query ({}) sent to Base Node, part {} of {} requests",
+                        "Output Manager {} query ({}) sent to Base Node, part {} of {} requests",
+                        query_type,
                         request_key,
                         r + 1,
                         rounds
@@ -866,8 +958,8 @@ where
         self.base_node_public_key = Some(base_node_public_key);

         if startup_query {
-            self.query_unspent_outputs_status(utxo_query_timeout_futures, None)
-                .await?;
+            self.query_unspent_outputs_status(utxo_query_timeout_futures).await?;
+            self.query_invalid_outputs_status(utxo_query_timeout_futures).await?;
         }
         Ok(())
     }
@@ -1026,3 +1118,17 @@ impl fmt::Display for Balance {
         Ok(())
     }
 }
+
+enum UtxoQueryType {
+    UnspentOutputs,
+    InvalidOutputs,
+}
+
+impl fmt::Display for UtxoQueryType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            UtxoQueryType::UnspentOutputs => write!(f, "Unspent Outputs"),
+            UtxoQueryType::InvalidOutputs => write!(f, "Invalid Outputs"),
+        }
+    }
+}
diff --git a/base_layer/wallet/src/output_manager_service/storage/database.rs b/base_layer/wallet/src/output_manager_service/storage/database.rs
index 0c3a35be1e..eaea0f0b45 100644
--- a/base_layer/wallet/src/output_manager_service/storage/database.rs
+++ b/base_layer/wallet/src/output_manager_service/storage/database.rs
@@ -83,6 +83,8 @@ pub trait OutputManagerBackend: Send + Sync {
     /// If an unspent output is detected as invalid (i.e. not available on the blockchain) then it should be moved to
     /// the invalid outputs collection. The function will return the last recorded TxId associated with this output.
     fn invalidate_unspent_output(&self, output: &DbUnblindedOutput) -> Result<Option<TxId>, OutputManagerStorageError>;
+    /// If an invalid output is found to be valid this function will turn it back into an unspent output
+    fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError>;
 }

 /// Holds the outputs that have been selected for a given pending transaction waiting for confirmation
@@ -502,6 +504,14 @@ where T: OutputManagerBackend + 'static
             .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string())))
             .and_then(|inner_result| inner_result)
     }
+
+    pub async fn revalidate_output(&self, spending_key: BlindingFactor) -> Result<(), OutputManagerStorageError> {
+        let db_clone = self.db.clone();
+        tokio::task::spawn_blocking(move || db_clone.revalidate_unspent_output(&spending_key))
+            .await
+            .or_else(|err| Err(OutputManagerStorageError::BlockingTaskSpawnError(err.to_string())))
+            .and_then(|inner_result| inner_result)
+    }
 }

 fn unexpected_result(req: DbKey, res: DbValue) -> Result {
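Note on the new OutputManagerDatabase::revalidate_output wrapper above: like invalidate_output it hands the blocking backend call to tokio::task::spawn_blocking and flattens the nested Result. Below is a minimal usage sketch, assuming the crate paths and signatures shown in this diff; the helper name try_revalidate is hypothetical and not part of the patch, and ValuesNotFound is treated as "nothing to revalidate" based on how both backends in this patch report that case.

    use tari_core::transactions::types::BlindingFactor;

    use crate::output_manager_service::{
        error::OutputManagerStorageError,
        storage::database::{OutputManagerBackend, OutputManagerDatabase},
    };

    /// Attempt to move a previously invalidated output back into the unspent set.
    /// Hypothetical helper: ValuesNotFound (output missing or not Invalid) is a no-op.
    pub async fn try_revalidate<T: OutputManagerBackend + 'static>(
        db: &OutputManagerDatabase<T>,
        spending_key: BlindingFactor,
    ) -> Result<bool, OutputManagerStorageError>
    {
        match db.revalidate_output(spending_key).await {
            Ok(()) => Ok(true),
            Err(OutputManagerStorageError::ValuesNotFound) => Ok(false),
            Err(e) => Err(e),
        }
    }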
diff --git a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs
index 3b6c5fe2c0..969f7803bf 100644
--- a/base_layer/wallet/src/output_manager_service/storage/memory_db.rs
+++ b/base_layer/wallet/src/output_manager_service/storage/memory_db.rs
@@ -42,6 +42,7 @@ use std::{
     sync::{Arc, RwLock},
     time::Duration,
 };
+use tari_core::transactions::types::BlindingFactor;

 /// This structure is an In-Memory database backend that implements the `OutputManagerBackend` trait and provides all
 /// the functionality required by the trait.
@@ -372,6 +373,22 @@ impl OutputManagerBackend for OutputManagerMemoryDatabase {
             None => Err(OutputManagerStorageError::ValuesNotFound),
         }
     }
+
+    fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> {
+        let mut db = acquire_write_lock!(self.db);
+        match db
+            .invalid_outputs
+            .iter()
+            .position(|v| v.output.unblinded_output.spending_key == *spending_key)
+        {
+            Some(pos) => {
+                let output = db.invalid_outputs.remove(pos);
+                db.unspent_outputs.push(output);
+                Ok(())
+            },
+            None => Err(OutputManagerStorageError::ValuesNotFound),
+        }
+    }
 }

 // A struct that contains the extra info we are using in the Sql version of this backend
diff --git a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs
index c2ea85e7a6..363773a557 100644
--- a/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs
+++ b/base_layer/wallet/src/output_manager_service/storage/sqlite_db.rs
@@ -54,7 +54,7 @@ use tari_core::{
     transactions::{
         tari_amount::MicroTari,
         transaction::{OutputFeatures, OutputFlags, UnblindedOutput},
-        types::{CryptoFactories, PrivateKey},
+        types::{BlindingFactor, CryptoFactories, PrivateKey},
     },
 };
 use tari_crypto::tari_utilities::ByteArray;
@@ -429,6 +429,23 @@ impl OutputManagerBackend for OutputManagerSqliteDatabase {

         Ok(tx_id)
     }
+
+    fn revalidate_unspent_output(&self, spending_key: &BlindingFactor) -> Result<(), OutputManagerStorageError> {
+        let conn = acquire_lock!(self.database_connection);
+        let output = OutputSql::find(&spending_key.to_vec(), &conn)?;
+
+        if OutputStatus::try_from(output.status)? != OutputStatus::Invalid {
+            return Err(OutputManagerStorageError::ValuesNotFound);
+        }
+        let _ = output.update(
+            UpdateOutput {
+                status: Some(OutputStatus::Unspent),
+                tx_id: None,
+            },
+            &(*conn),
+        )?;
+        Ok(())
+    }
 }

 /// A utility function to construct a PendingTransactionOutputs structure for a TxId, set of Outputs and a Timestamp
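An aside on the changed round computation in query_outputs_status (service.rs above): the ceil() expression was replaced by ((len as f32) / (max_utxo_query_size as f32 + 0.1)) as usize + 1. The stand-alone check below uses an illustrative batch size, not the wallet's configured max_utxo_query_size. For batch counts of this magnitude the two forms agree; the visible behavioural difference is that an empty hash list now still yields one round, i.e. one empty FetchUtxos request. (For extremely large batch counts the truncating form could drift below the exact ceiling, but that is far outside these sizes.)

    // Stand-alone sanity check of the round-count arithmetic; all values are illustrative.
    fn rounds_new(len: usize, max_query_size: usize) -> usize {
        ((len as f32) / (max_query_size as f32 + 0.1)) as usize + 1
    }

    fn rounds_old(len: usize, max_query_size: usize) -> usize {
        ((len as f32) / (max_query_size as f32)).ceil() as usize
    }

    fn main() {
        let size = 100;
        for len in [1usize, 99, 100, 101, 250] {
            // Non-empty cases of this magnitude: both expressions give the same batch count.
            assert_eq!(rounds_new(len, size), rounds_old(len, size));
        }
        // The empty case differs: the old form scheduled no request, the new form schedules one.
        assert_eq!(rounds_old(0, size), 0);
        assert_eq!(rounds_new(0, size), 1);
        println!("round-count check passed");
    }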
diff --git a/base_layer/wallet/tests/output_manager_service/service.rs b/base_layer/wallet/tests/output_manager_service/service.rs
index 54436c0736..00f84bbb95 100644
--- a/base_layer/wallet/tests/output_manager_service/service.rs
+++ b/base_layer/wallet/tests/output_manager_service/service.rs
@@ -73,8 +73,9 @@ use tari_wallet::{
         handle::{OutputManagerEvent, OutputManagerHandle},
         service::OutputManagerService,
         storage::{
-            database::{DbKey, DbValue, OutputManagerBackend, OutputManagerDatabase},
+            database::{DbKey, DbKeyValuePair, DbValue, OutputManagerBackend, OutputManagerDatabase, WriteOperation},
             memory_db::OutputManagerMemoryDatabase,
+            models::DbUnblindedOutput,
             sqlite_db::OutputManagerSqliteDatabase,
         },
     },
@@ -652,15 +653,34 @@ fn test_startup_utxo_scan() {
     let factories = CryptoFactories::default();
     let mut runtime = Runtime::new().unwrap();
+    let backend = OutputManagerMemoryDatabase::new();
+
+    let invalid_key = PrivateKey::random(&mut OsRng);
+    let invalid_value = 666;
+    let invalid_output = UnblindedOutput::new(MicroTari::from(invalid_value), invalid_key.clone(), None);
+    let invalid_hash = invalid_output.as_transaction_output(&factories).unwrap().hash();
+
+    backend
+        .write(WriteOperation::Insert(DbKeyValuePair::UnspentOutput(
+            invalid_output.spending_key.clone(),
+            Box::new(DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap()),
+        )))
+        .unwrap();
+    backend
+        .invalidate_unspent_output(
+            &DbUnblindedOutput::from_unblinded_output(invalid_output.clone(), &factories).unwrap(),
+        )
+        .unwrap();

     let (mut oms, outbound_service, _shutdown, mut base_node_response_sender, _) =
-        setup_output_manager_service(&mut runtime, OutputManagerMemoryDatabase::new());
+        setup_output_manager_service(&mut runtime, backend);

     let mut hashes = Vec::new();
     let key1 = PrivateKey::random(&mut OsRng);
     let value1 = 500;
     let output1 = UnblindedOutput::new(MicroTari::from(value1), key1.clone(), None);
     let tx_output1 = output1.as_transaction_output(&factories).unwrap();
-    hashes.push(tx_output1.hash());
+    let output1_hash = tx_output1.hash();
+    hashes.push(output1_hash.clone());

     runtime.block_on(oms.add_output(output1.clone())).unwrap();

     let key2 = PrivateKey::random(&mut OsRng);
@@ -679,6 +699,14 @@ fn test_startup_utxo_scan() {

     runtime.block_on(oms.add_output(output3.clone())).unwrap();

+    let key4 = PrivateKey::random(&mut OsRng);
+    let value4 = 901;
+    let output4 = UnblindedOutput::new(MicroTari::from(value4), key4.clone(), None);
+    let tx_output4 = output4.as_transaction_output(&factories).unwrap();
+    hashes.push(tx_output4.hash());
+
+    runtime.block_on(oms.add_output(output4.clone())).unwrap();
+
     let base_node_identity = NodeIdentity::random(
         &mut OsRng,
         "/ip4/127.0.0.1/tcp/58217".parse().unwrap(),
@@ -691,8 +719,11 @@ fn test_startup_utxo_scan() {
         .unwrap();

     outbound_service
-        .wait_call_count(2, Duration::from_secs(60))
+        .wait_call_count(3, Duration::from_secs(60))
         .expect("call wait 1");
+
+    let (_, _) = outbound_service.pop_call().unwrap(); // Burn the invalid request
+
     let (_, body) = outbound_service.pop_call().unwrap();
     let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
     let bn_request1: BaseNodeProto::BaseNodeServiceRequest = envelope_body
@@ -734,13 +765,13 @@ fn test_startup_utxo_scan() {
     let result_stream = runtime.block_on(async {
         collect_stream!(
             oms.get_event_stream_fused().map(|i| (*i).clone()),
-            take = 2,
+            take = 3,
             timeout = Duration::from_secs(60)
         )
     });

     assert_eq!(
-        2,
+        3,
         result_stream.iter().fold(0, |acc, item| {
             if let OutputManagerEvent::BaseNodeSyncRequestTimedOut(_) = item {
                 acc + 1
@@ -750,19 +781,88 @@ fn test_startup_utxo_scan() {
         })
     );

+    // Test the response to the revalidation call first so as not to confuse the invalidation that happens during the
+    // responses to the Unspent UTXO queries
+    let mut invalid_request_key = 0;
+    let mut unspent_request_key_with_output1 = 0;
+
+    for _ in 0..3 {
+        let (_, body) = outbound_service.pop_call().unwrap();
+        let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
+        let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body
+            .decode_part::<BaseNodeProto::BaseNodeServiceRequest>(1)
+            .unwrap()
+            .unwrap();
+
+        let request_hashes = if let Request::FetchUtxos(outputs) = bn_request.request.unwrap() {
+            outputs.outputs
+        } else {
+            assert!(false, "Wrong request type");
+            Vec::new()
+        };
+
+        if request_hashes.iter().find(|i| **i == invalid_hash).is_some() {
+            invalid_request_key = bn_request.request_key;
+        }
+        if request_hashes.iter().find(|i| **i == output1_hash).is_some() {
+            unspent_request_key_with_output1 = bn_request.request_key;
+        }
+    }
+    assert_ne!(invalid_request_key, 0, "Should have found invalid request key");
+    assert_ne!(
+        unspent_request_key_with_output1, 0,
+        "Should have found request key for request with output 1 in it"
+    );
+
+    let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap();
+    assert_eq!(invalid_txs.len(), 1);

     let base_node_response = BaseNodeProto::BaseNodeServiceResponse {
+        request_key: invalid_request_key,
+        response: Some(BaseNodeResponseProto::TransactionOutputs(
+            BaseNodeProto::TransactionOutputs {
+                outputs: vec![invalid_output.clone().as_transaction_output(&factories).unwrap().into()].into(),
+            },
+        )),
+    };
+
+    runtime
+        .block_on(base_node_response_sender.send(create_dummy_message(
+            base_node_response,
+            base_node_identity.public_key(),
+        )))
+        .unwrap();
+
+    let mut event_stream = oms.get_event_stream_fused();
+
+    runtime.block_on(async {
+        let mut delay = delay_for(Duration::from_secs(60)).fuse();
+        let mut acc = 0;
+        loop {
+            futures::select! {
+                event = event_stream.select_next_some() => {
+                    if let OutputManagerEvent::ReceiveBaseNodeResponse(_) = (*event).clone() {
+                        acc += 1;
+                        if acc >= 1 {
+                            break;
+                        }
+                    }
+                },
+                () = delay => {
+                    break;
+                },
+            }
+        }
+        assert!(acc >= 1, "Did not receive enough responses");
+    });
+
+    let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap();
+    assert_eq!(invalid_txs.len(), 0);
+
     let key4 = PrivateKey::random(&mut OsRng);
     let value4 = 1000;
     let output4 = UnblindedOutput::new(MicroTari::from(value4), key4, None);

     runtime.block_on(oms.add_output(output4.clone())).unwrap();

-    let (_, body) = outbound_service.pop_call().unwrap();
-    let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
-    let bn_request: BaseNodeProto::BaseNodeServiceRequest = envelope_body
-        .decode_part::<BaseNodeProto::BaseNodeServiceRequest>(1)
-        .unwrap()
-        .unwrap();
-    let _ = outbound_service.pop_call().unwrap();
-
     let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap();
     assert_eq!(invalid_txs.len(), 0);

@@ -786,7 +886,7 @@ fn test_startup_utxo_scan() {
     assert_eq!(invalid_txs.len(), 0);

     let base_node_response = BaseNodeProto::BaseNodeServiceResponse {
-        request_key: bn_request.request_key.clone(),
+        request_key: unspent_request_key_with_output1,
         response: Some(BaseNodeResponseProto::TransactionOutputs(
             BaseNodeProto::TransactionOutputs {
                 outputs: vec![output1.clone().as_transaction_output(&factories).unwrap().into()].into(),
@@ -801,8 +901,6 @@ fn test_startup_utxo_scan() {
         )))
         .unwrap();

-    let mut event_stream = oms.get_event_stream_fused();
-
     runtime.block_on(async {
         let mut delay = delay_for(Duration::from_secs(60)).fuse();
         let mut acc = 0;
@@ -828,11 +926,12 @@ fn test_startup_utxo_scan() {
     assert_eq!(invalid_outputs.len(), 1);
     let check2 = invalid_outputs[0] == output2;
     let check3 = invalid_outputs[0] == output3;
+    let check4 = invalid_outputs[0] == output4;

-    assert!(check2 || check3, "One of these outputs should be invalid");
+    assert!(check2 || check3 || check4, "One of these outputs should be invalid");

     let unspent_outputs = runtime.block_on(oms.get_unspent_outputs()).unwrap();
-    assert_eq!(unspent_outputs.len(), 3);
+    assert_eq!(unspent_outputs.len(), 5);
     assert!(unspent_outputs.iter().find(|uo| uo == &&output1).is_some());
     if check2 {
         assert!(unspent_outputs.iter().find(|uo| uo == &&output3).is_some())
@@ -856,6 +955,13 @@ fn test_startup_utxo_scan() {
         .unwrap()
         .unwrap();

+    let (_, body) = outbound_service.pop_call().unwrap();
+    let envelope_body = EnvelopeBody::decode(body.to_vec().as_slice()).unwrap();
+    let bn_request3: BaseNodeProto::BaseNodeServiceRequest = envelope_body
+        .decode_part::<BaseNodeProto::BaseNodeServiceRequest>(1)
+        .unwrap()
+        .unwrap();
+
     let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap();
     assert_eq!(invalid_txs.len(), 1);

@@ -886,6 +992,19 @@ fn test_startup_utxo_scan() {
         )))
         .unwrap();

+    let base_node_response3 = BaseNodeProto::BaseNodeServiceResponse {
+        request_key: bn_request3.request_key.clone(),
+        response: Some(BaseNodeResponseProto::TransactionOutputs(
+            BaseNodeProto::TransactionOutputs { outputs: vec![].into() },
+        )),
+    };
+    runtime
+        .block_on(base_node_response_sender.send(create_dummy_message(
+            base_node_response3,
+            base_node_identity.public_key(),
+        )))
+        .unwrap();
+
     let mut event_stream = oms.get_event_stream_fused();

     runtime.block_on(async {
@@ -896,7 +1015,7 @@ fn test_startup_utxo_scan() {
                 event = event_stream.select_next_some() => {
                     if let OutputManagerEvent::ReceiveBaseNodeResponse(r) = (*event).clone() {
                         acc += 1;
-                        if acc >= 3 {
+                        if acc >= 4 {
                             break;
                         }
                     }
@@ -910,7 +1029,7 @@ fn test_startup_utxo_scan() {
     });

     let invalid_txs = runtime.block_on(oms.get_invalid_outputs()).unwrap();
-    assert_eq!(invalid_txs.len(), 4);
+    assert_eq!(invalid_txs.len(), 6);
 }

 fn sending_transaction_with_short_term_clear<T: OutputManagerBackend + 'static>(backend: T) {
diff --git a/base_layer/wallet/tests/output_manager_service/storage.rs b/base_layer/wallet/tests/output_manager_service/storage.rs
index d0092c38de..b3daf35d50 100644
--- a/base_layer/wallet/tests/output_manager_service/storage.rs
+++ b/base_layer/wallet/tests/output_manager_service/storage.rs
@@ -290,6 +290,7 @@ pub fn test_db_backend<T: OutputManagerBackend + 'static>(backend: T) {
         .block_on(db.invalidate_output(unspent_outputs[0].clone()))
         .unwrap();
     let invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap();
+    assert_eq!(invalid_outputs.len(), 1);
     assert_eq!(invalid_outputs[0], unspent_outputs[0]);

@@ -297,6 +298,40 @@ pub fn test_db_backend<T: OutputManagerBackend + 'static>(backend: T) {
         .block_on(db.invalidate_output(pending_txs[0].outputs_to_be_received[0].clone()))
         .unwrap();
     assert_eq!(tx_id, Some(pending_txs[0].tx_id));
+
+    // test revalidating output
+    let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap();
+    assert!(
+        unspent_outputs
+            .iter()
+            .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output)
+            .is_none(),
+        "Should not find output"
+    );
+
+    assert!(runtime
+        .block_on(
+            db.revalidate_output(
+                pending_txs[2].outputs_to_be_spent[0]
+                    .unblinded_output
+                    .spending_key
+                    .clone()
+            )
+        )
+        .is_err());
+    runtime
+        .block_on(db.revalidate_output(invalid_outputs[0].unblinded_output.spending_key.clone()))
+        .unwrap();
+    let new_invalid_outputs = runtime.block_on(db.get_invalid_outputs()).unwrap();
+    assert_eq!(new_invalid_outputs.len(), 1);
+    let unspent_outputs = runtime.block_on(db.get_unspent_outputs()).unwrap();
+    assert!(
+        unspent_outputs
+            .iter()
+            .find(|o| o.unblinded_output == invalid_outputs[0].unblinded_output)
+            .is_some(),
+        "Should find revalidated output"
+    );
 }

 #[test]
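Beyond the storage-level test above, a wallet consumer would observe the revalidation pass through the output manager's event stream. The sketch below is illustrative only: it reuses the handle methods and event variants exercised in tests/output_manager_service/service.rs (get_event_stream_fused, get_invalid_outputs, ReceiveBaseNodeResponse, BaseNodeSyncRequestTimedOut), but the exact signatures and error conversions are assumptions, not guaranteed by this patch.

    use futures::StreamExt;
    use tari_wallet::output_manager_service::handle::{OutputManagerEvent, OutputManagerHandle};

    // Hypothetical watcher: report how the invalid-output count changes once a base node
    // answers one of the outstanding unspent/invalid output queries.
    async fn watch_revalidation(mut oms: OutputManagerHandle) -> Result<(), Box<dyn std::error::Error>> {
        let before = oms.get_invalid_outputs().await?.len();

        let mut events = oms.get_event_stream_fused();
        while let Some(event) = events.next().await {
            match (*event).clone() {
                OutputManagerEvent::ReceiveBaseNodeResponse(request_key) => {
                    let after = oms.get_invalid_outputs().await?.len();
                    println!("query {} handled: invalid outputs {} -> {}", request_key, before, after);
                    break;
                },
                OutputManagerEvent::BaseNodeSyncRequestTimedOut(request_key) => {
                    // The service re-queues the query itself; we only log the timeout here.
                    println!("query {} timed out, service will retry", request_key);
                },
                _ => {},
            }
        }
        Ok(())
    }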