diff --git a/base_layer/wallet/src/connectivity_service/mock.rs b/base_layer/wallet/src/connectivity_service/mock.rs index 54f0295421..11228b6661 100644 --- a/base_layer/wallet/src/connectivity_service/mock.rs +++ b/base_layer/wallet/src/connectivity_service/mock.rs @@ -69,6 +69,11 @@ impl WalletConnectivityMock { self.base_node_watch.send(Some(base_node_peer)); } + pub async fn base_node_changed(&mut self) -> Option { + self.base_node_watch.changed().await; + self.base_node_watch.borrow().as_ref().cloned() + } + pub fn send_shutdown(&self) { self.base_node_wallet_rpc_client.send(None); self.base_node_sync_rpc_client.send(None); diff --git a/base_layer/wallet/src/output_manager_service/error.rs b/base_layer/wallet/src/output_manager_service/error.rs index 8adcc975af..293a9ac7c1 100644 --- a/base_layer/wallet/src/output_manager_service/error.rs +++ b/base_layer/wallet/src/output_manager_service/error.rs @@ -94,6 +94,8 @@ pub enum OutputManagerError { ServiceError(String), #[error("Base node is not synced")] BaseNodeNotSynced, + #[error("Base node changed")] + BaseNodeChanged, #[error("Invalid Sender Message Type")] InvalidSenderMessage, #[error("Coinbase build error: `{0}`")] diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index 258b9f60dc..d5bf4dd067 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -534,11 +534,13 @@ where } fn validate_outputs(&mut self) -> Result { - if !self.resources.connectivity.is_base_node_set() { - return Err(OutputManagerError::NoBaseNodeKeysProvided); - } + let current_base_node = self + .resources + .connectivity + .get_current_base_node_id() + .ok_or(OutputManagerError::NoBaseNodeKeysProvided)?; let id = OsRng.next_u64(); - let utxo_validation = TxoValidationTask::new( + let txo_validation = TxoValidationTask::new( id, self.resources.db.clone(), self.resources.connectivity.clone(), @@ -546,28 +548,56 @@ where self.resources.config.clone(), ); - let shutdown = self.resources.shutdown_signal.clone(); + let mut shutdown = self.resources.shutdown_signal.clone(); + let mut base_node_watch = self.resources.connectivity.get_current_base_node_watcher(); let event_publisher = self.resources.event_publisher.clone(); tokio::spawn(async move { - match utxo_validation.execute(shutdown).await { - Ok(id) => { - info!( - target: LOG_TARGET, - "UTXO Validation Protocol (Id: {}) completed successfully", id - ); - }, - Err(OutputManagerProtocolError { id, error }) => { - warn!( - target: LOG_TARGET, - "Error completing UTXO Validation Protocol (Id: {}): {:?}", id, error - ); - if let Err(e) = event_publisher.send(Arc::new(OutputManagerEvent::TxoValidationFailure(id))) { - debug!( - target: LOG_TARGET, - "Error sending event because there are no subscribers: {:?}", e - ); + let exec_fut = txo_validation.execute(); + tokio::pin!(exec_fut); + loop { + tokio::select! { + result = &mut exec_fut => { + match result { + Ok(id) => { + info!( + target: LOG_TARGET, + "UTXO Validation Protocol (Id: {}) completed successfully", id + ); + return; + }, + Err(OutputManagerProtocolError { id, error }) => { + warn!( + target: LOG_TARGET, + "Error completing UTXO Validation Protocol (Id: {}): {:?}", id, error + ); + if let Err(e) = event_publisher.send(Arc::new(OutputManagerEvent::TxoValidationFailure(id))) { + debug!( + target: LOG_TARGET, + "Error sending event because there are no subscribers: {:?}", e + ); + } + + return; + }, + } + }, + _ = shutdown.wait() => { + debug!(target: LOG_TARGET, "TXO Validation Protocol (Id: {}) shutting down because the system is shutting down", id); + return; + }, + _ = base_node_watch.changed() => { + if let Some(peer) = base_node_watch.borrow().as_ref() { + if peer.node_id != current_base_node { + debug!( + target: LOG_TARGET, + "TXO Validation Protocol (Id: {}) cancelled because base node changed", id + ); + return; + } + } + } - }, + } } }); Ok(id) diff --git a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs index d66589fd6a..0e90112dcb 100644 --- a/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs +++ b/base_layer/wallet/src/output_manager_service/tasks/txo_validation_task.rs @@ -27,14 +27,14 @@ use std::{ use log::*; use tari_common_types::types::{BlockHash, FixedHash}; -use tari_comms::protocol::rpc::RpcError::RequestFailed; +use tari_comms::{peer_manager::Peer, protocol::rpc::RpcError::RequestFailed}; use tari_core::{ base_node::rpc::BaseNodeWalletRpcClient, blocks::BlockHeader, proto::base_node::{QueryDeletedRequest, UtxoQueryRequest}, }; -use tari_shutdown::ShutdownSignal; use tari_utilities::hex::Hex; +use tokio::sync::watch; use crate::{ connectivity_service::WalletConnectivityInterface, @@ -54,6 +54,7 @@ const LOG_TARGET: &str = "wallet::output_service::txo_validation_task"; pub struct TxoValidationTask { operation_id: u64, db: OutputManagerDatabase, + base_node_watch: watch::Receiver>, connectivity: TWalletConnectivity, event_publisher: OutputManagerEventSender, config: OutputManagerServiceConfig, @@ -74,13 +75,14 @@ where Self { operation_id, db, + base_node_watch: connectivity.get_current_base_node_watcher(), connectivity, event_publisher, config, } } - pub async fn execute(mut self, _shutdown: ShutdownSignal) -> Result { + pub async fn execute(mut self) -> Result { let mut base_node_client = self .connectivity .obtain_base_node_wallet_rpc_client() @@ -88,9 +90,15 @@ where .ok_or(OutputManagerError::Shutdown) .for_protocol(self.operation_id)?; + let base_node_peer = self + .base_node_watch + .borrow() + .as_ref() + .map(|p| p.node_id.clone()) + .ok_or_else(|| OutputManagerProtocolError::new(self.operation_id, OutputManagerError::BaseNodeChanged))?; debug!( target: LOG_TARGET, - "Starting TXO validation protocol (Id: {})", self.operation_id, + "Starting TXO validation protocol with peer {} (Id: {})", base_node_peer, self.operation_id, ); let last_mined_header = self.check_for_reorgs(&mut base_node_client).await?; @@ -99,10 +107,11 @@ where self.update_spent_outputs(&mut base_node_client, last_mined_header) .await?; + self.publish_event(OutputManagerEvent::TxoValidationSuccess(self.operation_id)); debug!( target: LOG_TARGET, - "Finished TXO validation protocol (Id: {})", self.operation_id, + "Finished TXO validation protocol from base node {} (Id: {})", base_node_peer, self.operation_id, ); Ok(self.operation_id) } @@ -233,6 +242,7 @@ where batch.len(), self.operation_id ); + let (mined, unmined, tip_height) = self .query_base_node_for_outputs(batch, wallet_client) .await diff --git a/base_layer/wallet/src/transaction_service/error.rs b/base_layer/wallet/src/transaction_service/error.rs index eb934ecb75..fe5107bb3e 100644 --- a/base_layer/wallet/src/transaction_service/error.rs +++ b/base_layer/wallet/src/transaction_service/error.rs @@ -94,6 +94,8 @@ pub enum TransactionServiceError { AttemptedToBroadcastCoinbaseTransaction(TxId), #[error("No Base Node public keys are provided for Base chain broadcast and monitoring")] NoBaseNodeKeysProvided, + #[error("Base node changed during {task_name}")] + BaseNodeChanged { task_name: &'static str }, #[error("Error sending data to Protocol via registered channels")] ProtocolChannelError, #[error("Transaction detected as rejected by mempool")] diff --git a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs index 92ddedbef6..e6bbf2a64b 100644 --- a/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs +++ b/base_layer/wallet/src/transaction_service/protocols/transaction_validation_protocol.rs @@ -29,7 +29,7 @@ use std::{ use log::*; use tari_common_types::{ transaction::{TransactionStatus, TxId}, - types::BlockHash, + types::{BlockHash, Signature}, }; use tari_comms::protocol::rpc::{RpcError::RequestFailed, RpcStatusCode::NotFound}; use tari_core::{ @@ -51,6 +51,7 @@ use crate::{ handle::{TransactionEvent, TransactionEventSender}, storage::{ database::{TransactionBackend, TransactionDatabase}, + models::TxCancellationReason, sqlite_db::UnconfirmedTransactionInfo, }, }, @@ -67,9 +68,6 @@ pub struct TransactionValidationProtocol TransactionValidationProtocol @@ -504,10 +502,6 @@ where tx_id: TxId, status: &TransactionStatus, ) -> Result<(), TransactionServiceProtocolError> { - self.db - .set_transaction_as_unmined(tx_id) - .for_protocol(self.operation_id)?; - if *status == TransactionStatus::Coinbase { if let Err(e) = self.output_manager_handle.set_coinbase_abandoned(tx_id, false).await { warn!( @@ -520,6 +514,10 @@ where }; } + self.db + .set_transaction_as_unmined(tx_id) + .for_protocol(self.operation_id)?; + self.publish_event(TransactionEvent::TransactionBroadcast(tx_id)); Ok(()) } diff --git a/base_layer/wallet/src/transaction_service/service.rs b/base_layer/wallet/src/transaction_service/service.rs index a159e800a3..094ebae3b3 100644 --- a/base_layer/wallet/src/transaction_service/service.rs +++ b/base_layer/wallet/src/transaction_service/service.rs @@ -2180,9 +2180,12 @@ where JoinHandle>>, >, ) -> Result { - if !self.connectivity().is_base_node_set() { - return Err(TransactionServiceError::NoBaseNodeKeysProvided); - } + let current_base_node = self + .resources + .connectivity + .get_current_base_node_id() + .ok_or(TransactionServiceError::NoBaseNodeKeysProvided)?; + trace!(target: LOG_TARGET, "Starting transaction validation protocol"); let id = OperationId::new_random(); @@ -2195,7 +2198,29 @@ where self.resources.output_manager_service.clone(), ); - let join_handle = tokio::spawn(protocol.execute()); + let mut base_node_watch = self.connectivity().get_current_base_node_watcher(); + + let join_handle = tokio::spawn(async move { + let exec_fut = protocol.execute(); + tokio::pin!(exec_fut); + loop { + tokio::select! { + result = &mut exec_fut => { + return result; + }, + _ = base_node_watch.changed() => { + if let Some(peer) = base_node_watch.borrow().as_ref() { + if peer.node_id != current_base_node { + debug!(target: LOG_TARGET, "Base node changed, exiting transaction validation protocol"); + return Err(TransactionServiceProtocolError::new(id, TransactionServiceError::BaseNodeChanged { + task_name: "transaction validation_protocol", + })); + } + } + } + } + } + }); join_handles.push(join_handle); Ok(id) diff --git a/base_layer/wallet/tests/output_manager_service_tests/service.rs b/base_layer/wallet/tests/output_manager_service_tests/service.rs index c97bf9d096..de60490199 100644 --- a/base_layer/wallet/tests/output_manager_service_tests/service.rs +++ b/base_layer/wallet/tests/output_manager_service_tests/service.rs @@ -146,17 +146,20 @@ async fn setup_output_manager_service