diff --git a/src/transport/manager/address.rs b/src/transport/manager/address.rs index 6d813175..e18d7c05 100644 --- a/src/transport/manager/address.rs +++ b/src/transport/manager/address.rs @@ -102,7 +102,7 @@ impl AddressRecord { /// Update score of an address. pub fn update_score(&mut self, score: i32) { - self.score += score; + self.score = self.score.saturating_add(score); } /// Set `ConnectionId` for the [`AddressRecord`]. diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 0129e241..be732b1f 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -44,7 +44,7 @@ use parking_lot::RwLock; use tokio::sync::mpsc::{channel, Receiver, Sender}; use std::{ - collections::{HashMap, HashSet}, + collections::{hash_map::Entry, HashMap, HashSet}, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -640,43 +640,57 @@ impl TransportManager { { let mut peers = self.peers.write(); - match peers.get_mut(&remote_peer_id) { - None => { - drop(peers); + match peers.entry(remote_peer_id) { + Entry::Occupied(occupied) => { + let context = occupied.into_mut(); - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "dial address first time entering dial state"); + // For a better address tacking, see: + // https://github.com/paritytech/litep2p/issues/180 + // + // TODO: context.addresses.insert(record.clone()); - self.peers.write().insert( - remote_peer_id, - PeerContext { - state: PeerState::Dialing { - record: record.clone(), - }, - addresses: AddressStore::new(), - secondary_connection: None, - }, + tracing::debug!( + target: LOG_TARGET, + peer = ?remote_peer_id, + state = ?context.state, + "peer state exists", ); - } - Some(PeerContext { - state: - PeerState::Dialing { .. } - | PeerState::Connected { .. } - | PeerState::Opening { .. }, - .. - }) => { - tracing::debug!(target: LOG_TARGET, address = ?record.address(), "dial address returning early"); - return Ok(()); - } - Some(PeerContext { ref mut state, .. }) => { - // TODO: verify that the address is not in `addresses` already - // addresses.insert(address.clone()); - tracing::debug!(target: LOG_TARGET, address = ?record.address(), ?state, "dial address entering dial state"); - *state = PeerState::Dialing { - record: record.clone(), - }; + match context.state { + PeerState::Connected { .. } => { + return Err(Error::AlreadyConnected); + } + PeerState::Dialing { .. } | PeerState::Opening { .. } => { + return Ok(()); + } + PeerState::Disconnected { + dial_record: Some(_), + } => { + tracing::debug!( + target: LOG_TARGET, + peer = ?remote_peer_id, + state = ?context.state, + "peer is already being dialed from a disconnected state" + ); + return Ok(()); + } + PeerState::Disconnected { dial_record: None } => { + context.state = PeerState::Dialing { + record: record.clone(), + }; + } + } } - } + Entry::Vacant(vacant) => { + vacant.insert(PeerContext { + state: PeerState::Dialing { + record: record.clone(), + }, + addresses: AddressStore::new(), + secondary_connection: None, + }); + } + }; } self.transports @@ -696,8 +710,6 @@ impl TransportManager { ?connection_id, "dial failed for a connection that doesn't exist", ); - debug_assert!(false); - Error::InvalidState })?; @@ -707,7 +719,7 @@ impl TransportManager { target: LOG_TARGET, ?peer, ?connection_id, - "dial failed for a peer that doens't exist", + "dial failed for a peer that doesn't exist", ); debug_assert!(false); @@ -720,6 +732,21 @@ impl TransportManager { ) { PeerState::Dialing { ref mut record } => { debug_assert_eq!(record.connection_id(), &Some(connection_id)); + if record.connection_id() != &Some(connection_id) { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?record, + "unknown dial failure for a dialing peer", + ); + + context.state = PeerState::Dialing { + record: record.clone(), + }; + debug_assert!(false); + return Ok(()); + } record.update_score(SCORE_CONNECT_FAILURE); context.addresses.insert(record.clone()); @@ -734,6 +761,23 @@ impl TransportManager { record, dial_record: Some(mut dial_record), } => { + if dial_record.connection_id() != &Some(connection_id) { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?record, + "unknown dial failure for a connected peer", + ); + + context.state = PeerState::Connected { + record, + dial_record: Some(dial_record), + }; + debug_assert!(false); + return Ok(()); + } + dial_record.update_score(SCORE_CONNECT_FAILURE); context.addresses.insert(dial_record); @@ -753,6 +797,22 @@ impl TransportManager { "dial failed for a disconnected peer", ); + if dial_record.connection_id() != &Some(connection_id) { + tracing::warn!( + target: LOG_TARGET, + ?peer, + ?connection_id, + ?dial_record, + "unknown dial failure for a disconnected peer", + ); + + context.state = PeerState::Disconnected { + dial_record: Some(dial_record), + }; + debug_assert!(false); + return Ok(()); + } + dial_record.update_score(SCORE_CONNECT_FAILURE); context.addresses.insert(dial_record); @@ -1098,7 +1158,7 @@ impl TransportManager { // since an inbound connection was removed, the outbound connection can be // removed from pending dials // - // all records have the same `ConnectionId` so it doens't matter which of them + // all records have the same `ConnectionId` so it doesn't matter which of them // is used to remove the pending dial self.pending_connections.remove( &records @@ -1336,8 +1396,6 @@ impl TransportManager { ?connection_id, "open failure but dial record doesn't exist", ); - - debug_assert!(false); return Err(Error::InvalidState); }; @@ -3144,6 +3202,25 @@ mod tests { }; manager.dial(peer).await.unwrap(); + + // Check state is unaltered. + { + let peers = manager.peers.read(); + let peer_context = peers.get(&peer).unwrap(); + + match &peer_context.state { + PeerState::Dialing { record } => { + assert_eq!( + record.address(), + &Multiaddr::empty() + .with(Protocol::Ip4(std::net::Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p(Multihash::from(peer))) + ); + } + state => panic!("invalid state: {state:?}"), + } + } } #[tokio::test] @@ -3656,10 +3733,62 @@ mod tests { } #[tokio::test] - async fn reject_unknown_secondary_connections_with_different_connection_ids() { - // This is the repro case for https://github.com/paritytech/litep2p/issues/172. + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ConnectionLimitsConfig::default(), + ); + manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new())); + + // Random peer ID. + let peer = PeerId::random(); + let (first_addr, first_connection_id) = setup_dial_addr(peer, 0); + let second_connection_id = ConnectionId::from(1); + let different_connection_id = ConnectionId::from(2); + + // Setup a connected peer with a dial record active. + { + let mut peers = manager.peers.write(); + + let state = PeerState::Connected { + record: AddressRecord::new(&peer, first_addr.clone(), 0, Some(first_connection_id)), + dial_record: Some(AddressRecord::new( + &peer, + first_addr.clone(), + 0, + Some(second_connection_id), + )), + }; + + let peer_context = PeerContext { + state, + secondary_connection: None, + addresses: AddressStore::from_iter(vec![first_addr.clone()].into_iter()), + }; + + peers.insert(peer, peer_context); + } + + // Establish a connection, however the connection ID is different. + let result = manager + .on_connection_established( + peer, + &Endpoint::dialer(first_addr.clone(), different_connection_id), + ) + .unwrap(); + assert_eq!(result, ConnectionEstablishedResult::Reject); + } + #[tokio::test] + async fn guard_against_secondary_connections_with_different_connection_ids() { + // This is the repro case for https://github.com/paritytech/litep2p/issues/172. let _ = tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .try_init(); @@ -3690,7 +3819,7 @@ mod tests { // Setup addresses. let (first_addr, first_connection_id) = setup_dial_addr(0); - let (second_addr, second_connection_id) = setup_dial_addr(1); + let (second_addr, _second_connection_id) = setup_dial_addr(1); let (remote_addr, remote_connection_id) = setup_dial_addr(2); // Step 1. Dialing state to peer. @@ -3758,16 +3887,17 @@ mod tests { } } - // Step 4. Dial by the second address to overwrite the state. + // Step 4. Dial by the second address and expect to not overwrite the state. manager.dial_address(second_addr.clone()).await.unwrap(); - // check the state of the peer. + // The state remains unchanged since we already have a dialing in flight. { let peers = manager.peers.read(); let peer_context = peers.get(&peer).unwrap(); match &peer_context.state { - PeerState::Dialing { record } => { - assert_eq!(record.address(), &second_addr); - assert_eq!(record.connection_id(), &Some(second_connection_id)); + PeerState::Disconnected { dial_record } => { + let dial_record = dial_record.as_ref().unwrap(); + assert_eq!(dial_record.address(), &first_addr); + assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); } state => panic!("invalid state: {state:?}"), } @@ -3792,10 +3922,10 @@ mod tests { assert_eq!(record.address(), &remote_addr); assert_eq!(record.connection_id(), &Some(remote_connection_id)); - // We have overwritten the first dial record in step 4. + // We have not overwritten the first dial record in step 4. let dial_record = dial_record.as_ref().unwrap(); - assert_eq!(dial_record.address(), &second_addr); - assert_eq!(dial_record.connection_id(), &Some(second_connection_id)); + assert_eq!(dial_record.address(), &first_addr); + assert_eq!(dial_record.connection_id(), &Some(first_connection_id)); } state => panic!("invalid state: {state:?}"), } @@ -3808,8 +3938,81 @@ mod tests { &Endpoint::dialer(first_addr.clone(), first_connection_id), ) .unwrap(); - // We have rejected the connection because we have another secondary connection - // in flight (with a different connection ID). - assert_eq!(result, ConnectionEstablishedResult::Reject); + assert_eq!(result, ConnectionEstablishedResult::Accept); + } + + #[tokio::test] + async fn do_not_overwrite_dial_addresses() { + let _ = tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .try_init(); + + let (mut manager, _handle) = TransportManager::new( + Keypair::generate(), + HashSet::new(), + BandwidthSink::new(), + 8usize, + ConnectionLimitsConfig::default(), + ); + let peer = PeerId::random(); + let dial_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8888)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + let connection_id = ConnectionId::from(0); + let transport = Box::new({ + let mut transport = DummyTransport::new(); + transport.inject_event(TransportEvent::ConnectionEstablished { + peer, + endpoint: Endpoint::listener(dial_address.clone(), connection_id), + }); + transport + }); + manager.register_transport(SupportedTransport::Tcp, transport); + + // First dial attempt. + manager.dial_address(dial_address.clone()).await.unwrap(); + // check the state of the peer. + { + let peers = manager.peers.read(); + let peer_context = peers.get(&peer).unwrap(); + match &peer_context.state { + PeerState::Dialing { record } => { + assert_eq!(record.address(), &dial_address); + } + state => panic!("invalid state: {state:?}"), + } + + // The address is not saved yet. + assert!(!peer_context.addresses.contains(&dial_address)); + } + + let second_address = Multiaddr::empty() + .with(Protocol::Ip4(Ipv4Addr::new(127, 0, 0, 1))) + .with(Protocol::Tcp(8889)) + .with(Protocol::P2p( + Multihash::from_bytes(&peer.to_bytes()).unwrap(), + )); + + // Second dial attempt with different address. + manager.dial_address(second_address.clone()).await.unwrap(); + // check the state of the peer. + { + let peers = manager.peers.read(); + let peer_context = peers.get(&peer).unwrap(); + match &peer_context.state { + // Must still be dialing the first address. + PeerState::Dialing { record } => { + assert_eq!(record.address(), &dial_address); + } + state => panic!("invalid state: {state:?}"), + } + + assert!(!peer_context.addresses.contains(&dial_address)); + assert!(!peer_context.addresses.contains(&second_address)); + } } }