diff --git a/applications/tari_base_node/src/commands/command_handler.rs b/applications/tari_base_node/src/commands/command_handler.rs index 606b16f59e..0e02f8184a 100644 --- a/applications/tari_base_node/src/commands/command_handler.rs +++ b/applications/tari_base_node/src/commands/command_handler.rs @@ -23,6 +23,7 @@ use std::{ cmp, io::{self, Write}, + ops::Deref, str::FromStr, string::ToString, sync::Arc, @@ -364,7 +365,7 @@ impl CommandHandler { println!("🌎 Peer discovery started."); let peer = self .discovery_service - .discover_peer(dest_pubkey.clone(), NodeDestination::PublicKey(dest_pubkey)) + .discover_peer(dest_pubkey.deref().clone(), NodeDestination::PublicKey(dest_pubkey)) .await?; println!("⚡️ Discovery succeeded in {}ms!", start.elapsed().as_millis()); println!("This peer was found:"); diff --git a/applications/tari_console_wallet/src/automation/commands.rs b/applications/tari_console_wallet/src/automation/commands.rs index 1558fd933d..769dcb83b1 100644 --- a/applications/tari_console_wallet/src/automation/commands.rs +++ b/applications/tari_console_wallet/src/automation/commands.rs @@ -322,14 +322,17 @@ pub async fn discover_peer( ) -> Result<(), CommandError> { use ParsedArgument::*; let dest_public_key = match args[0].clone() { - PublicKey(key) => Ok(Box::new(key)), + PublicKey(key) => Ok(key), _ => Err(CommandError::Argument), }?; let start = Instant::now(); println!("🌎 Peer discovery started."); match dht_service - .discover_peer(dest_public_key.clone(), NodeDestination::PublicKey(dest_public_key)) + .discover_peer( + dest_public_key.clone(), + NodeDestination::PublicKey(Box::new(dest_public_key)), + ) .await { Ok(peer) => { diff --git a/applications/tari_validator_node/src/dan_node.rs b/applications/tari_validator_node/src/dan_node.rs index 069f8861c2..1b2097a1b0 100644 --- a/applications/tari_validator_node/src/dan_node.rs +++ b/applications/tari_validator_node/src/dan_node.rs @@ -213,9 +213,7 @@ impl DanNode { let chain_storage = SqliteStorageService {}; let wallet_client = GrpcWalletClient::new(config.wallet_grpc_address); let checkpoint_manager = ConcreteCheckpointManager::new(asset_definition.clone(), wallet_client); - let connectivity = handles.expect_handle(); - let validator_node_client_factory = - TariCommsValidatorNodeClientFactory::new(connectivity, dht.discovery_service_requester()); + let validator_node_client_factory = TariCommsValidatorNodeClientFactory::new(dht.dht_requester()); let mut consensus_worker = ConsensusWorker::::new( receiver, outbound, diff --git a/applications/tari_validator_node/src/main.rs b/applications/tari_validator_node/src/main.rs index a6790869ba..ca5c476513 100644 --- a/applications/tari_validator_node/src/main.rs +++ b/applications/tari_validator_node/src/main.rs @@ -44,7 +44,7 @@ use tari_common::{ exit_codes::{ExitCode, ExitError}, GlobalConfig, }; -use tari_comms::{connectivity::ConnectivityRequester, peer_manager::PeerFeatures, NodeIdentity}; +use tari_comms::{peer_manager::PeerFeatures, NodeIdentity}; use tari_comms_dht::Dht; use tari_dan_core::services::{ConcreteAssetProcessor, ConcreteAssetProxy, MempoolServiceHandle, ServiceSpecification}; use tari_dan_storage_sqlite::SqliteDbFactory; @@ -124,10 +124,8 @@ async fn run_node(config: GlobalConfig, create_id: bool) -> Result<(), ExitError .await?; let asset_processor = ConcreteAssetProcessor::default(); - let validator_node_client_factory = TariCommsValidatorNodeClientFactory::new( - handles.expect_handle::(), - handles.expect_handle::().discovery_service_requester(), - ); + let validator_node_client_factory = + TariCommsValidatorNodeClientFactory::new(handles.expect_handle::().dht_requester()); let asset_proxy: ConcreteAssetProxy = ConcreteAssetProxy::new( GrpcBaseNodeClient::new(validator_node_config.base_node_grpc_address), validator_node_client_factory, diff --git a/applications/tari_validator_node/src/p2p/services/rpc_client.rs b/applications/tari_validator_node/src/p2p/services/rpc_client.rs index 3add896ecc..fea3ac0f92 100644 --- a/applications/tari_validator_node/src/p2p/services/rpc_client.rs +++ b/applications/tari_validator_node/src/p2p/services/rpc_client.rs @@ -25,13 +25,8 @@ use std::convert::TryInto; use async_trait::async_trait; use log::*; use tari_common_types::types::PublicKey; -use tari_comms::{ - connection_manager::ConnectionManagerError, - connectivity::{ConnectivityError, ConnectivityRequester}, - peer_manager::{NodeId, PeerManagerError}, - PeerConnection, -}; -use tari_comms_dht::{envelope::NodeDestination, DhtDiscoveryRequester}; +use tari_comms::PeerConnection; +use tari_comms_dht::DhtRequester; use tari_crypto::tari_utilities::ByteArray; use tari_dan_core::{ models::{Node, SchemaState, SideChainBlock, StateOpLogEntry, TemplateId, TreeNodeHash}, @@ -44,48 +39,14 @@ use crate::p2p::{proto::validator_node as proto, rpc}; const LOG_TARGET: &str = "tari::validator_node::p2p::services::rpc_client"; pub struct TariCommsValidatorNodeRpcClient { - connectivity: ConnectivityRequester, - dht_discovery: DhtDiscoveryRequester, + dht: DhtRequester, address: PublicKey, } impl TariCommsValidatorNodeRpcClient { async fn create_connection(&mut self) -> Result { - match self.connectivity.dial_peer(NodeId::from(self.address.clone())).await { - Ok(connection) => Ok(connection), - Err(connectivity_error) => { - dbg!(&connectivity_error); - match &connectivity_error { - ConnectivityError::ConnectionFailed(err) => { - match err { - ConnectionManagerError::PeerConnectionError(_) | - ConnectionManagerError::DialConnectFailedAllAddresses | - ConnectionManagerError::PeerIdentityNoValidAddresses | - ConnectionManagerError::PeerManagerError(PeerManagerError::PeerNotFoundError) => { - // Try discover, then dial again - // TODO: Should make discovery and connect the responsibility of the DHT layer - self.dht_discovery - .discover_peer( - Box::new(self.address.clone()), - NodeDestination::PublicKey(Box::new(self.address.clone())), - ) - .await?; - if let Some(conn) = self - .connectivity - .get_connection(NodeId::from(self.address.clone())) - .await? - { - return Ok(conn); - } - Ok(self.connectivity.dial_peer(NodeId::from(self.address.clone())).await?) - }, - _ => Err(connectivity_error.into()), - } - }, - _ => Err(connectivity_error.into()), - } - }, - } + let conn = self.dht.dial_or_discover_peer(self.address.clone()).await?; + Ok(conn) } } @@ -280,16 +241,12 @@ impl ValidatorNodeRpcClient for TariCommsValidatorNodeRpcClient { #[derive(Clone)] pub struct TariCommsValidatorNodeClientFactory { - connectivity_requester: ConnectivityRequester, - dht_discovery: DhtDiscoveryRequester, + dht: DhtRequester, } impl TariCommsValidatorNodeClientFactory { - pub fn new(connectivity_requester: ConnectivityRequester, dht_discovery: DhtDiscoveryRequester) -> Self { - Self { - connectivity_requester, - dht_discovery, - } + pub fn new(dht: DhtRequester) -> Self { + Self { dht } } } @@ -299,8 +256,7 @@ impl ValidatorNodeClientFactory for TariCommsValidatorNodeClientFactory { fn create_client(&self, address: &Self::Addr) -> Self::Client { TariCommsValidatorNodeRpcClient { - connectivity: self.connectivity_requester.clone(), - dht_discovery: self.dht_discovery.clone(), + dht: self.dht.clone(), address: address.clone(), } } diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index 14787b5175..046e030c59 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -144,7 +144,7 @@ pub async fn discovery(wallets: &[TestNode], messaging_events_rx: &mut NodeEvent .dht .discovery_service_requester() .discover_peer( - Box::new(wallet2.node_identity().public_key().clone()), + wallet2.node_identity().public_key().clone(), wallet2.node_identity().node_id().clone().into(), ) .await; diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index b648d65f1a..5b0b38c207 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -27,15 +27,17 @@ //! //! [DhtRequest]: ./enum.DhtRequest.html -use std::{cmp, fmt, fmt::Display, sync::Arc}; +use std::{cmp, fmt, fmt::Display, sync::Arc, time::Instant}; use chrono::{DateTime, Utc}; use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use log::*; use tari_comms::{ + connection_manager::ConnectionManagerError, connectivity::{ConnectivityError, ConnectivityRequester, ConnectivitySelection}, peer_manager::{NodeId, NodeIdentity, PeerFeatures, PeerManager, PeerManagerError, PeerQuery, PeerQuerySortBy}, types::CommsPublicKey, + PeerConnection, }; use tari_crypto::tari_utilities::hex::Hex; use tari_shutdown::ShutdownSignal; @@ -56,6 +58,7 @@ use crate::{ proto::{dht::JoinMessage, envelope::DhtMessageType}, storage::{DbConnection, DhtDatabase, DhtMetadataKey, StorageError}, DhtConfig, + DhtDiscoveryRequester, }; const LOG_TARGET: &str = "comms::dht::actor"; @@ -107,6 +110,10 @@ pub enum DhtRequest { SelectPeers(BroadcastStrategy, oneshot::Sender>), GetMetadata(DhtMetadataKey, oneshot::Sender>, DhtActorError>>), SetMetadata(DhtMetadataKey, Vec, oneshot::Sender>), + DialDiscoverPeer { + public_key: CommsPublicKey, + reply: oneshot::Sender>, + }, } impl Display for DhtRequest { @@ -130,6 +137,7 @@ impl Display for DhtRequest { SetMetadata(key, value, _) => { write!(f, "SetMetadata (key={}, value={} bytes)", key, value.len()) }, + DialDiscoverPeer { public_key, .. } => write!(f, "DialDiscoverPeer(public_key={})", public_key), } } } @@ -199,6 +207,19 @@ impl DhtRequester { self.sender.send(DhtRequest::SetMetadata(key, bytes, reply_tx)).await?; reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)? } + + /// Attempt to dial a peer. If the peer is not known, a discovery will be initiated. If discovery succeeds, a + /// connection to the peer will be returned. + pub async fn dial_or_discover_peer(&mut self, public_key: CommsPublicKey) -> Result { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender + .send(DhtRequest::DialDiscoverPeer { + public_key, + reply: reply_tx, + }) + .await?; + reply_rx.await.map_err(|_| DhtActorError::ReplyCanceled)? + } } pub struct DhtActor { @@ -208,6 +229,7 @@ pub struct DhtActor { outbound_requester: OutboundMessageRequester, connectivity: ConnectivityRequester, config: Arc, + discovery: DhtDiscoveryRequester, shutdown_signal: ShutdownSignal, request_rx: mpsc::Receiver, msg_hash_dedup_cache: DedupCacheDatabase, @@ -222,6 +244,7 @@ impl DhtActor { connectivity: ConnectivityRequester, outbound_requester: OutboundMessageRequester, request_rx: mpsc::Receiver, + discovery: DhtDiscoveryRequester, shutdown_signal: ShutdownSignal, ) -> Self { debug!( @@ -238,6 +261,7 @@ impl DhtActor { peer_manager, connectivity, node_identity, + discovery, shutdown_signal, request_rx, } @@ -385,6 +409,17 @@ impl DhtActor { Ok(()) }) }, + DialDiscoverPeer { public_key, reply } => { + let connectivity = self.connectivity.clone(); + let discovery = self.discovery.clone(); + let peer_manager = self.peer_manager.clone(); + Box::pin(async move { + let mut task = DiscoveryDialTask::new(connectivity, peer_manager, discovery); + let result = task.run(public_key).await; + let _ = reply.send(result); + Ok(()) + }) + }, } } @@ -709,12 +744,78 @@ impl DhtActor { } } +struct DiscoveryDialTask { + connectivity: ConnectivityRequester, + peer_manager: Arc, + discovery: DhtDiscoveryRequester, +} + +impl DiscoveryDialTask { + pub fn new( + connectivity: ConnectivityRequester, + peer_manager: Arc, + discovery: DhtDiscoveryRequester, + ) -> Self { + Self { + connectivity, + peer_manager, + discovery, + } + } + + pub async fn run(&mut self, public_key: CommsPublicKey) -> Result { + if self.peer_manager.exists(&public_key).await { + let node_id = NodeId::from_public_key(&public_key); + match self.connectivity.dial_peer(node_id).await { + Ok(conn) => Ok(conn), + Err(ConnectivityError::ConnectionFailed(err)) => match err { + ConnectionManagerError::ConnectFailedMaximumAttemptsReached | + ConnectionManagerError::DialConnectFailedAllAddresses => { + debug!( + target: LOG_TARGET, + "Dial failed for peer {}. Attempting discovery.", public_key + ); + self.discover_peer(public_key).await + }, + err => Err(ConnectivityError::from(err).into()), + }, + Err(err) => Err(err.into()), + } + } else { + debug!( + target: LOG_TARGET, + "Peer '{}' not found, initiating discovery", public_key + ); + self.discover_peer(public_key).await + } + } + + async fn discover_peer(&mut self, public_key: CommsPublicKey) -> Result { + let node_id = NodeId::from_public_key(&public_key); + let timer = Instant::now(); + let _ = self + .discovery + .discover_peer(public_key.clone(), public_key.into()) + .await?; + debug!( + target: LOG_TARGET, + "Discovery succeeded for peer {} in {:.2?}", + node_id, + timer.elapsed() + ); + let conn = self.connectivity.dial_peer(node_id).await?; + Ok(conn) + } +} + #[cfg(test)] mod test { + use std::time::Duration; + use chrono::{DateTime, Utc}; use tari_comms::{ runtime, - test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair}, + test_utils::mocks::{create_connectivity_mock, create_peer_connection_mock_pair, ConnectivityManagerMockState}, }; use tari_shutdown::Shutdown; use tari_test_utils::random; @@ -723,7 +824,13 @@ mod test { use crate::{ broadcast_strategy::BroadcastClosestRequest, envelope::NodeDestination, - test_utils::{build_peer_manager, make_client_identity, make_node_identity}, + test_utils::{ + build_peer_manager, + create_dht_discovery_mock, + make_client_identity, + make_node_identity, + DhtDiscoveryMockState, + }, }; async fn db_connection() -> DbConnection { @@ -742,6 +849,7 @@ mod test { let (actor_tx, actor_rx) = mpsc::channel(1); let mut requester = DhtRequester::new(actor_tx); let outbound_requester = OutboundMessageRequester::new(out_tx); + let (discovery, _) = create_dht_discovery_mock(Duration::from_secs(10)); let shutdown = Shutdown::new(); let actor = DhtActor::new( Default::default(), @@ -751,6 +859,7 @@ mod test { connectivity_manager, outbound_requester, actor_rx, + discovery, shutdown.to_signal(), ); @@ -761,6 +870,94 @@ mod test { assert_eq!(params.dht_message_type, DhtMessageType::Join); } + mod discovery_dial_peer { + use super::*; + use crate::test_utils::make_peer; + + async fn setup( + shutdown_signal: ShutdownSignal, + ) -> ( + DhtRequester, + Arc, + ConnectivityManagerMockState, + DhtDiscoveryMockState, + Arc, + ) { + let node_identity = make_node_identity(); + let peer_manager = build_peer_manager(); + let (out_tx, _) = mpsc::channel(1); + let (connectivity_manager, mock) = create_connectivity_mock(); + let connectivity_mock = mock.get_shared_state(); + mock.spawn(); + let (actor_tx, actor_rx) = mpsc::channel(1); + let requester = DhtRequester::new(actor_tx); + let outbound_requester = OutboundMessageRequester::new(out_tx); + let (discovery, mock) = create_dht_discovery_mock(Duration::from_secs(10)); + let discovery_mock = mock.get_shared_state(); + mock.spawn(); + DhtActor::new( + Default::default(), + db_connection().await, + node_identity.clone(), + peer_manager.clone(), + connectivity_manager, + outbound_requester, + actor_rx, + discovery, + shutdown_signal, + ) + .spawn(); + + ( + requester, + node_identity, + connectivity_mock, + discovery_mock, + peer_manager, + ) + } + + #[runtime::test] + async fn it_discovers_a_peer() { + let shutdown = Shutdown::new(); + let (mut dht, node_identity, connectivity_mock, discovery_mock, _) = setup(shutdown.to_signal()).await; + let peer = make_peer(); + discovery_mock.set_discover_peer_response(peer.clone()); + let (conn1, _, _, _) = create_peer_connection_mock_pair(node_identity.to_peer(), peer.clone()).await; + connectivity_mock.add_active_connection(conn1).await; + + let conn = dht.dial_or_discover_peer(peer.public_key).await.unwrap(); + assert_eq!(*conn.peer_node_id(), peer.node_id); + assert_eq!(discovery_mock.call_count(), 1); + } + + #[runtime::test] + async fn it_gets_active_peer_connection() { + let shutdown = Shutdown::new(); + let (mut dht, node_identity, connectivity_mock, discovery_mock, peer_manager) = + setup(shutdown.to_signal()).await; + let peer = make_peer(); + peer_manager.add_peer(peer.clone()).await.unwrap(); + let (conn1, _, _, _) = create_peer_connection_mock_pair(node_identity.to_peer(), peer.clone()).await; + connectivity_mock.add_active_connection(conn1).await; + + let conn = dht.dial_or_discover_peer(peer.public_key).await.unwrap(); + assert_eq!(*conn.peer_node_id(), peer.node_id); + assert_eq!(discovery_mock.call_count(), 0); + assert_eq!(connectivity_mock.call_count().await, 1); + } + + #[runtime::test] + async fn it_errors_if_discovery_fails_for_unknown_peer() { + let shutdown = Shutdown::new(); + let (mut dht, _, connectivity_mock, discovery_mock, _) = setup(shutdown.to_signal()).await; + let peer = make_peer(); + let _ = dht.dial_or_discover_peer(peer.public_key.clone()).await.unwrap_err(); + assert_eq!(discovery_mock.call_count(), 1); + assert_eq!(connectivity_mock.call_count().await, 0); + } + } + #[runtime::test] async fn insert_message_signature() { let node_identity = make_node_identity(); @@ -770,6 +967,7 @@ mod test { let (out_tx, _) = mpsc::channel(1); let (actor_tx, actor_rx) = mpsc::channel(1); let mut requester = DhtRequester::new(actor_tx); + let (discovery, _) = create_dht_discovery_mock(Duration::from_secs(10)); let outbound_requester = OutboundMessageRequester::new(out_tx); let shutdown = Shutdown::new(); let actor = DhtActor::new( @@ -780,6 +978,7 @@ mod test { connectivity_manager, outbound_requester, actor_rx, + discovery, shutdown.to_signal(), ); @@ -813,6 +1012,7 @@ mod test { let (actor_tx, actor_rx) = mpsc::channel(1); let mut requester = DhtRequester::new(actor_tx); let outbound_requester = OutboundMessageRequester::new(out_tx); + let (discovery, _) = create_dht_discovery_mock(Duration::from_secs(10)); let shutdown = Shutdown::new(); // Note: This must be equal or larger than the minimum dedup cache capacity for DedupCacheDatabase let capacity = 10; @@ -827,6 +1027,7 @@ mod test { connectivity_manager, outbound_requester, actor_rx, + discovery, shutdown.to_signal(), ); @@ -899,6 +1100,7 @@ mod test { let connectivity_manager_mock_state = mock.get_shared_state(); mock.spawn(); + let (discovery, _) = create_dht_discovery_mock(Duration::from_secs(10)); let (conn_in, _, conn_out, _) = create_peer_connection_mock_pair(client_node_identity.to_peer(), node_identity.to_peer()).await; connectivity_manager_mock_state.add_active_connection(conn_in).await; @@ -918,6 +1120,7 @@ mod test { connectivity_manager, outbound_requester, actor_rx, + discovery, shutdown.to_signal(), ); @@ -1007,6 +1210,7 @@ mod test { let (connectivity_manager, mock) = create_connectivity_mock(); mock.spawn(); let mut requester = DhtRequester::new(actor_tx); + let (discovery, _) = create_dht_discovery_mock(Duration::from_secs(10)); let outbound_requester = OutboundMessageRequester::new(out_tx); let mut shutdown = Shutdown::new(); let actor = DhtActor::new( @@ -1017,6 +1221,7 @@ mod test { connectivity_manager, outbound_requester, actor_rx, + discovery, shutdown.to_signal(), ); diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index fcdb8826a8..775074ec2c 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -181,6 +181,7 @@ impl Dht { self.connectivity.clone(), self.outbound_requester(), request_receiver, + self.discovery_service_requester(), shutdown_signal, ) } diff --git a/comms/dht/src/discovery/mod.rs b/comms/dht/src/discovery/mod.rs index 4a41e8faca..31b92622af 100644 --- a/comms/dht/src/discovery/mod.rs +++ b/comms/dht/src/discovery/mod.rs @@ -21,8 +21,12 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod error; +pub use error::DhtDiscoveryError; + mod requester; -mod service; +pub use requester::DhtDiscoveryRequester; pub(crate) use self::requester::DhtDiscoveryRequest; -pub use self::{error::DhtDiscoveryError, requester::DhtDiscoveryRequester, service::DhtDiscoveryService}; + +mod service; +pub use service::DhtDiscoveryService; diff --git a/comms/dht/src/discovery/requester.rs b/comms/dht/src/discovery/requester.rs index 567c545497..c7195074ea 100644 --- a/comms/dht/src/discovery/requester.rs +++ b/comms/dht/src/discovery/requester.rs @@ -31,7 +31,8 @@ use tokio::{ time, }; -use crate::{discovery::DhtDiscoveryError, envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; +use super::DhtDiscoveryError; +use crate::{envelope::NodeDestination, proto::dht::DiscoveryResponseMessage}; #[derive(Debug)] pub enum DhtDiscoveryRequest { @@ -85,14 +86,14 @@ impl DhtDiscoveryRequester { /// quicker discovery times. pub async fn discover_peer( &mut self, - dest_public_key: Box, + dest_public_key: CommsPublicKey, destination: NodeDestination, ) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); self.sender .send(DhtDiscoveryRequest::DiscoverPeer( - dest_public_key, + Box::new(dest_public_key), destination, reply_tx, )) @@ -109,7 +110,7 @@ impl DhtDiscoveryRequester { .map_err(|_| DhtDiscoveryError::ReplyCanceled)? } - pub async fn notify_discovery_response_received( + pub(crate) async fn notify_discovery_response_received( &mut self, response: DiscoveryResponseMessage, ) -> Result<(), DhtDiscoveryError> { diff --git a/comms/dht/src/discovery/service.rs b/comms/dht/src/discovery/service.rs index 1d883e468d..84a9c3c414 100644 --- a/comms/dht/src/discovery/service.rs +++ b/comms/dht/src/discovery/service.rs @@ -387,7 +387,7 @@ mod test { let dest_public_key = Box::new(CommsPublicKey::default()); let result = requester .discover_peer( - dest_public_key.clone(), + *dest_public_key.clone(), NodeDestination::PublicKey(dest_public_key.clone()), ) .await; diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 5b47ca2934..0979d5c0e8 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -381,7 +381,7 @@ where S: Service // Peer not found, let's try and discover it match self .dht_discovery_requester - .discover_peer(dest_public_key.clone(), NodeDestination::PublicKey(dest_public_key)) + .discover_peer(*dest_public_key.clone(), NodeDestination::PublicKey(dest_public_key)) .await { // Peer found! @@ -577,7 +577,6 @@ mod test { create_dht_discovery_mock, make_peer, service_spy, - DhtDiscoveryMockState, }, }; @@ -609,7 +608,7 @@ mod test { )); let (dht_requester, dht_mock) = create_dht_actor_mock(10); - let (dht_discover_requester, _) = create_dht_discovery_mock(10, Duration::from_secs(10)); + let (dht_discover_requester, _) = create_dht_discovery_mock(Duration::from_secs(10)); let mock_state = dht_mock.get_shared_state(); mock_state.set_select_peers_response(vec![example_peer.clone(), other_peer.clone()]); @@ -658,7 +657,7 @@ mod test { ); let (dht_requester, dht_mock) = create_dht_actor_mock(10); task::spawn(dht_mock.run()); - let (dht_discover_requester, _) = create_dht_discovery_mock(10, Duration::from_secs(10)); + let (dht_discover_requester, _) = create_dht_discovery_mock(Duration::from_secs(10)); let spy = service_spy(); let mut service = BroadcastMiddleware::new( @@ -700,10 +699,9 @@ mod test { ); let (dht_requester, dht_mock) = create_dht_actor_mock(10); task::spawn(dht_mock.run()); - let (dht_discover_requester, mut discovery_mock) = create_dht_discovery_mock(10, Duration::from_secs(10)); - let dht_discovery_state = DhtDiscoveryMockState::new(); - discovery_mock.set_shared_state(dht_discovery_state.clone()); - task::spawn(discovery_mock.run()); + let (dht_discover_requester, discovery_mock) = create_dht_discovery_mock(Duration::from_secs(10)); + let dht_discovery_state = discovery_mock.get_shared_state(); + discovery_mock.spawn(); let peer_to_discover = make_peer(); dht_discovery_state.set_discover_peer_response(peer_to_discover.clone()); diff --git a/comms/dht/src/test_utils/dht_actor_mock.rs b/comms/dht/src/test_utils/dht_actor_mock.rs index 91407ef702..eccf145bf8 100644 --- a/comms/dht/src/test_utils/dht_actor_mock.rs +++ b/comms/dht/src/test_utils/dht_actor_mock.rs @@ -134,6 +134,7 @@ impl DhtActorMock { self.state.settings.write().unwrap().insert(key.to_string(), value); reply_tx.send(Ok(())).unwrap(); }, + DialDiscoverPeer { .. } => unimplemented!(), } } } diff --git a/comms/dht/src/test_utils/dht_discovery_mock.rs b/comms/dht/src/test_utils/dht_discovery_mock.rs index 91975c8c99..8667da8305 100644 --- a/comms/dht/src/test_utils/dht_discovery_mock.rs +++ b/comms/dht/src/test_utils/dht_discovery_mock.rs @@ -31,36 +31,36 @@ use std::{ use log::*; use tari_comms::peer_manager::Peer; -use tokio::sync::mpsc; +use tokio::{sync::mpsc, task}; use crate::{ discovery::{DhtDiscoveryRequest, DhtDiscoveryRequester}, - test_utils::make_peer, + DhtDiscoveryError, }; const LOG_TARGET: &str = "comms::dht::discovery_mock"; -pub fn create_dht_discovery_mock(buf_size: usize, timeout: Duration) -> (DhtDiscoveryRequester, DhtDiscoveryMock) { - let (tx, rx) = mpsc::channel(buf_size); +pub fn create_dht_discovery_mock(timeout: Duration) -> (DhtDiscoveryRequester, DhtDiscoveryMock) { + let (tx, rx) = mpsc::channel(10); (DhtDiscoveryRequester::new(tx, timeout), DhtDiscoveryMock::new(rx)) } #[derive(Debug, Clone)] pub struct DhtDiscoveryMockState { call_count: Arc, - discover_peer: Arc>, + discover_peer: Arc>>, } impl DhtDiscoveryMockState { pub fn new() -> Self { Self { call_count: Arc::new(AtomicUsize::new(0)), - discover_peer: Arc::new(RwLock::new(make_peer())), + discover_peer: Arc::new(RwLock::new(None)), } } pub fn set_discover_peer_response(&self, peer: Peer) -> &Self { - *self.discover_peer.write().unwrap() = peer; + *self.discover_peer.write().unwrap() = Some(peer); self } @@ -86,14 +86,16 @@ impl DhtDiscoveryMock { } } - pub fn set_shared_state(&mut self, state: DhtDiscoveryMockState) { - self.state = state; + pub fn get_shared_state(&self) -> DhtDiscoveryMockState { + self.state.clone() } - pub async fn run(mut self) { - while let Some(req) = self.receiver.recv().await { - self.handle_request(req).await; - } + pub fn spawn(mut self) { + task::spawn(async move { + while let Some(req) = self.receiver.recv().await { + self.handle_request(req).await; + } + }); } async fn handle_request(&self, req: DhtDiscoveryRequest) { @@ -103,7 +105,9 @@ impl DhtDiscoveryMock { match req { DiscoverPeer(_, _, reply_tx) => { let lock = self.state.discover_peer.read().unwrap(); - reply_tx.send(Ok(lock.clone())).unwrap(); + reply_tx + .send(lock.clone().ok_or(DhtDiscoveryError::DiscoveryTimeout)) + .unwrap(); }, NotifyDiscoveryResponseReceived(_) => {}, } diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index e05ae253f7..61a602f8d9 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -357,7 +357,7 @@ async fn dht_discover_propagation() { .dht .discovery_service_requester() .discover_peer( - Box::new(node_D.node_identity().public_key().clone()), + node_D.node_identity().public_key().clone(), node_D.node_identity().node_id().clone().into(), ) .await diff --git a/comms/src/connection_manager/common.rs b/comms/src/connection_manager/common.rs index 44e19e42ed..f78a096546 100644 --- a/comms/src/connection_manager/common.rs +++ b/comms/src/connection_manager/common.rs @@ -192,7 +192,7 @@ pub fn validate_peer_addresses<'a, A: IntoIterator>( validate_address(addr, allow_test_addrs)?; } if !has_address { - return Err(ConnectionManagerError::PeerHasNoAddresses); + return Err(ConnectionManagerError::PeerIdentityNoAddresses); } Ok(()) } diff --git a/comms/src/connection_manager/error.rs b/comms/src/connection_manager/error.rs index b4121f1eb5..b6e7fe6f37 100644 --- a/comms/src/connection_manager/error.rs +++ b/comms/src/connection_manager/error.rs @@ -87,7 +87,7 @@ pub enum ConnectionManagerError { #[error("Peer did not provide the identity timestamp")] PeerIdentityNoUpdatedTimestampProvided, #[error("Peer did not provide any public addresses")] - PeerHasNoAddresses, + PeerIdentityNoAddresses, } impl From for ConnectionManagerError { diff --git a/dan_layer/core/src/services/validator_node_rpc_client.rs b/dan_layer/core/src/services/validator_node_rpc_client.rs index 672b29351d..1215f518bf 100644 --- a/dan_layer/core/src/services/validator_node_rpc_client.rs +++ b/dan_layer/core/src/services/validator_node_rpc_client.rs @@ -27,7 +27,7 @@ use tari_comms::{ protocol::rpc::{RpcError, RpcStatus}, types::CommsPublicKey, }; -use tari_comms_dht::DhtDiscoveryError; +use tari_comms_dht::DhtActorError; use crate::{ models::{Node, SchemaState, SideChainBlock, StateOpLogEntry, TemplateId, TreeNodeHash}, @@ -91,6 +91,6 @@ pub enum ValidatorNodeClientError { RpcError(#[from] RpcError), #[error("Remote node returned error: {0}")] RpcStatusError(#[from] RpcStatus), - #[error("Dht Discovery error: {0}")] - DhtDiscoveryError(#[from] DhtDiscoveryError), + #[error("Dht error: {0}")] + DhtError(#[from] DhtActorError), }