Skip to content

Commit

Permalink
Fix race between outbound messages and peer disconnection
Browse files Browse the repository at this point in the history
Previously, outbound messages held in `process_events` could race
with peer disconnection, allowing a message intended for a peer
before disconnection to be sent to the same peer after
disconnection.

The fix is simple: hold the peers read lock while we fetch pending messages
from the handlers, since disconnection happens under the peers write lock.
  • Loading branch information
TheBlueMatt committed Oct 14, 2023
1 parent e7690dd commit f1e97f9
Showing 1 changed file with 139 additions and 21 deletions.
160 changes: 139 additions & 21 deletions lightning/src/ln/peer_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1891,15 +1891,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
let flush_read_disabled = self.gossip_processing_backlog_lifted.swap(false, Ordering::Relaxed);

let mut peers_to_disconnect = HashMap::new();
let mut events_generated = self.message_handler.chan_handler.get_and_clear_pending_msg_events();
events_generated.append(&mut self.message_handler.route_handler.get_and_clear_pending_msg_events());

{
// TODO: There are some DoS attacks here where you can flood someone's outbound send
// buffer by doing things like announcing channels on another node. We should be willing to
// drop optional-ish messages when send buffers get full!

let peers_lock = self.peers.read().unwrap();

let mut events_generated = self.message_handler.chan_handler.get_and_clear_pending_msg_events();
events_generated.append(&mut self.message_handler.route_handler.get_and_clear_pending_msg_events());

let peers = &*peers_lock;
macro_rules! get_peer_for_forwarding {
($node_id: expr) => {
Expand Down Expand Up @@ -2520,12 +2517,11 @@ mod tests {

use crate::prelude::*;
use crate::sync::{Arc, Mutex};
use core::convert::Infallible;
use core::sync::atomic::{AtomicBool, Ordering};
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

#[derive(Clone)]
struct FileDescriptor {
fd: u16,
fd: u32,
outbound_data: Arc<Mutex<Vec<u8>>>,
disconnect: Arc<AtomicBool>,
}
Expand Down Expand Up @@ -2560,24 +2556,44 @@ mod tests {

struct TestCustomMessageHandler {
features: InitFeatures,
peer_counter: AtomicUsize,
send_messages: Option<PublicKey>,
}

// Test-only wire-type impl so bare `u64`s can be exchanged as custom messages by
// `TestCustomMessageHandler`; 4242 is an arbitrary message type id for the tests.
impl crate::ln::wire::Type for u64 {
// Every value serializes under the same fixed custom message type id.
fn type_id(&self) -> u16 { 4242 }
}

impl wire::CustomMessageReader for TestCustomMessageHandler {
type CustomMessage = Infallible;
fn read<R: io::Read>(&self, _: u16, _: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError> {
Ok(None)
type CustomMessage = u64;
fn read<R: io::Read>(&self, msg_type: u16, reader: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError> {
assert!(self.send_messages.is_some());
assert_eq!(msg_type, 4242);
let mut msg = [0u8; 8];
reader.read_exact(&mut msg).unwrap();
Ok(Some(u64::from_be_bytes(msg)))
}
}

impl CustomMessageHandler for TestCustomMessageHandler {
fn handle_custom_message(&self, _: Infallible, _: &PublicKey) -> Result<(), LightningError> {
unreachable!();
fn handle_custom_message(&self, msg: u64, _: &PublicKey) -> Result<(), LightningError> {
assert_eq!(self.peer_counter.load(Ordering::Acquire) as u64, msg);
Ok(())
}

fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> { Vec::new() }
fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> {
if let Some(peer_node_id) = &self.send_messages {
vec![(*peer_node_id, self.peer_counter.load(Ordering::Acquire) as u64); 1000]
} else { Vec::new() }
}

fn peer_disconnected(&self, _: &PublicKey) {}
fn peer_connected(&self, _: &PublicKey, _: &msgs::Init, _: bool) -> Result<(), ()> { Ok(()) }
fn peer_disconnected(&self, _: &PublicKey) {
self.peer_counter.fetch_sub(1, Ordering::AcqRel);
}
fn peer_connected(&self, _: &PublicKey, _: &msgs::Init, _: bool) -> Result<(), ()> {
self.peer_counter.fetch_add(2, Ordering::AcqRel);
Ok(())
}

fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }

Expand All @@ -2600,7 +2616,9 @@ mod tests {
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
logger: test_utils::TestLogger::new(),
routing_handler: test_utils::TestRoutingMessageHandler::new(),
custom_handler: TestCustomMessageHandler { features },
custom_handler: TestCustomMessageHandler {
features, peer_counter: AtomicUsize::new(0), send_messages: None,
},
node_signer: test_utils::TestNodeSigner::new(node_secret),
}
);
Expand All @@ -2623,7 +2641,9 @@ mod tests {
chan_handler: test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet)),
logger: test_utils::TestLogger::new(),
routing_handler: test_utils::TestRoutingMessageHandler::new(),
custom_handler: TestCustomMessageHandler { features },
custom_handler: TestCustomMessageHandler {
features, peer_counter: AtomicUsize::new(0), send_messages: None,
},
node_signer: test_utils::TestNodeSigner::new(node_secret),
}
);
Expand All @@ -2643,7 +2663,9 @@ mod tests {
chan_handler: test_utils::TestChannelMessageHandler::new(network),
logger: test_utils::TestLogger::new(),
routing_handler: test_utils::TestRoutingMessageHandler::new(),
custom_handler: TestCustomMessageHandler { features },
custom_handler: TestCustomMessageHandler {
features, peer_counter: AtomicUsize::new(0), send_messages: None,
},
node_signer: test_utils::TestNodeSigner::new(node_secret),
}
);
Expand Down Expand Up @@ -3191,4 +3213,100 @@ mod tests {
thread_c.join().unwrap();
assert!(cfg[0].chan_handler.message_fetch_counter.load(Ordering::Acquire) >= 1);
}

#[test]
#[cfg(feature = "std")]
fn test_rapid_connect_events_order_multithreaded() {
	// Previously, outbound messages held in `process_events` could race with peer
	// disconnection, allowing a message intended for a peer before disconnection to be sent
	// to the same peer after disconnection. Here we stress the handling of such messages by
	// connecting two peers repeatedly in a loop with a `CustomMessageHandler` set to stream
	// custom messages with a "connection id" to each other. That "connection id" (just the
	// number of reconnections seen) should always line up across both peers, which we assert
	// in the message handler.
	let mut cfg = create_peermgr_cfgs(2);
	// Each handler must stream its messages at the *other* peer's node id. (Targeting
	// cfg[1]'s own id, as a previous revision did, means that peer's messages are
	// addressed to a node it is never connected to and get silently dropped.)
	cfg[0].custom_handler.send_messages =
		Some(cfg[1].node_signer.get_node_id(Recipient::Node).unwrap());
	cfg[1].custom_handler.send_messages =
		Some(cfg[0].node_signer.get_node_id(Recipient::Node).unwrap());
	let cfg = Arc::new(cfg);
	// Until we have std::thread::scoped we have to unsafe { turn off the borrow checker }.
	let mut peers = create_network(2, unsafe { &*(&*cfg as *const _) as &'static _ });
	let peer_a = Arc::new(peers.pop().unwrap());
	let peer_b = Arc::new(peers.pop().unwrap());

	let exit_flag = Arc::new(AtomicBool::new(false));
	// Spawns a thread which performs one full handshake between the two peers and then
	// shovels pending outbound bytes in both directions until told to exit (or either
	// side's descriptor is flagged disconnected). `$id` must be unique per thread so the
	// two managers see distinct socket descriptors for each connection attempt.
	macro_rules! spawn_thread { ($id: expr) => { {
		let thread_peer_a = Arc::clone(&peer_a);
		let thread_peer_b = Arc::clone(&peer_b);
		let thread_exit = Arc::clone(&exit_flag);
		std::thread::spawn(move || {
			let id_a = thread_peer_a.node_signer.get_node_id(Recipient::Node).unwrap();
			let mut fd_a = FileDescriptor {
				fd: $id, outbound_data: Arc::new(Mutex::new(Vec::new())),
				disconnect: Arc::new(AtomicBool::new(false)),
			};
			let addr_a = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1000};
			let mut fd_b = FileDescriptor {
				fd: $id, outbound_data: Arc::new(Mutex::new(Vec::new())),
				disconnect: Arc::new(AtomicBool::new(false)),
			};
			let addr_b = SocketAddress::TcpIpV4{addr: [127, 0, 0, 1], port: 1001};
			// B dials A; feed B's act-one bytes into A to kick off the handshake.
			let initial_data = thread_peer_b.new_outbound_connection(id_a, fd_b.clone(), Some(addr_a.clone())).unwrap();
			thread_peer_a.new_inbound_connection(fd_a.clone(), Some(addr_b.clone())).unwrap();
			if thread_peer_a.read_event(&mut fd_a, &initial_data).is_err() {
				thread_peer_b.socket_disconnected(&fd_b);
				return;
			}

			loop {
				if thread_exit.load(Ordering::Relaxed) {
					thread_peer_a.socket_disconnected(&fd_a);
					thread_peer_b.socket_disconnected(&fd_b);
					return;
				}
				// If either manager already asked for a disconnect, the manager has
				// dropped the peer itself; don't call socket_disconnected again.
				if fd_a.disconnect.load(Ordering::Relaxed) { return; }
				if fd_b.disconnect.load(Ordering::Relaxed) { return; }

				let data_a = fd_a.outbound_data.lock().unwrap().split_off(0);
				if !data_a.is_empty() {
					if thread_peer_b.read_event(&mut fd_b, &data_a).is_err() {
						thread_peer_a.socket_disconnected(&fd_a);
						return;
					}
				}

				let data_b = fd_b.outbound_data.lock().unwrap().split_off(0);
				if !data_b.is_empty() {
					if thread_peer_a.read_event(&mut fd_a, &data_b).is_err() {
						thread_peer_b.socket_disconnected(&fd_b);
						return;
					}
				}
			}
		})
	} } }

	let mut threads = Vec::new();
	{
		let thread_peer_a = Arc::clone(&peer_a);
		let thread_peer_b = Arc::clone(&peer_b);
		let thread_exit = Arc::clone(&exit_flag);
		// Continuously drive `process_events` on both managers so pending custom
		// messages race with the connect/disconnect churn from the threads below.
		threads.push(std::thread::spawn(move || {
			while !thread_exit.load(Ordering::Relaxed) {
				thread_peer_a.process_events();
				thread_peer_b.process_events();
			}
		}));
	}
	for i in 0..1000 {
		threads.push(spawn_thread!(i));
	}
	exit_flag.store(true, Ordering::Relaxed);
	for thread in threads {
		thread.join().unwrap();
	}
	// All connections must have been torn down cleanly by the time every thread exits.
	assert_eq!(peer_a.peers.read().unwrap().len(), 0);
	assert_eq!(peer_b.peers.read().unwrap().len(), 0);
}
}

0 comments on commit f1e97f9

Please sign in to comment.