Skip to content

Commit

Permalink
Merge pull request #2006 from TheBlueMatt/2023-02-no-recursive-read-l…
Browse files Browse the repository at this point in the history
…ocks

Refuse recursive read locks
  • Loading branch information
wpaulino authored Feb 28, 2023
2 parents b8bea74 + 065dc6e commit 8311581
Show file tree
Hide file tree
Showing 13 changed files with 275 additions and 198 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ jobs:
cargo test --verbose --color always --features esplora-async
- name: Test backtrace-debug builds on Rust ${{ matrix.toolchain }}
if: "matrix.toolchain == 'stable'"
shell: bash # Default on Winblows is powershell
run: |
cd lightning && cargo test --verbose --color always --features backtrace
cd lightning && RUST_BACKTRACE=1 cargo test --verbose --color always --features backtrace
- name: Test on Rust ${{ matrix.toolchain }} with net-tokio
if: "matrix.build-net-tokio && !matrix.coverage"
run: cargo test --verbose --color always
Expand Down
17 changes: 12 additions & 5 deletions lightning/src/chain/channelmonitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ use core::{cmp, mem};
use crate::io::{self, Error};
use core::convert::TryInto;
use core::ops::Deref;
use crate::sync::Mutex;
use crate::sync::{Mutex, LockTestExt};

/// An update generated by the underlying channel itself which contains some new information the
/// [`ChannelMonitor`] should be made aware of.
Expand Down Expand Up @@ -851,9 +851,13 @@ pub type TransactionOutputs = (Txid, Vec<(u32, TxOut)>);

impl<Signer: WriteableEcdsaChannelSigner> PartialEq for ChannelMonitor<Signer> where Signer: PartialEq {
fn eq(&self, other: &Self) -> bool {
let inner = self.inner.lock().unwrap();
let other = other.inner.lock().unwrap();
inner.eq(&other)
// We need some kind of total lockorder. Absent a better idea, we sort by position in
// memory and take locks in that order (assuming that we can't move within memory while a
// lock is held).
let ord = ((self as *const _) as usize) < ((other as *const _) as usize);
let a = if ord { self.inner.unsafe_well_ordered_double_lock_self() } else { other.inner.unsafe_well_ordered_double_lock_self() };
let b = if ord { other.inner.unsafe_well_ordered_double_lock_self() } else { self.inner.unsafe_well_ordered_double_lock_self() };
a.eq(&b)
}
}

Expand Down Expand Up @@ -4066,7 +4070,10 @@ mod tests {
fn test_prune_preimages() {
let secp_ctx = Secp256k1::new();
let logger = Arc::new(TestLogger::new());
let broadcaster = Arc::new(TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new()), blocks: Arc::new(Mutex::new(Vec::new()))});
let broadcaster = Arc::new(TestBroadcaster {
txn_broadcasted: Mutex::new(Vec::new()),
blocks: Arc::new(Mutex::new(Vec::new()))
});
let fee_estimator = TestFeeEstimator { sat_per_kw: Mutex::new(253) };

let dummy_key = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
Expand Down
17 changes: 10 additions & 7 deletions lightning/src/ln/chanmon_update_fail_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,13 @@ fn test_monitor_and_persister_update_fail() {
blocks: Arc::new(Mutex::new(vec![(genesis_block(Network::Testnet), 200); 200])),
};
let chain_mon = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let mut w = test_utils::TestVecWriter(Vec::new());
monitor.write(&mut w).unwrap();
let new_monitor = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
let new_monitor = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let new_monitor = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
new_monitor
};
let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
assert_eq!(chain_mon.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
chain_mon
Expand Down Expand Up @@ -1426,9 +1427,11 @@ fn monitor_failed_no_reestablish_response() {
{
let mut node_0_per_peer_lock;
let mut node_0_peer_state_lock;
get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived;
}
{
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived;
get_channel_ref!(nodes[1], nodes[0], node_1_per_peer_lock, node_1_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived;
}

Expand Down
155 changes: 97 additions & 58 deletions lightning/src/ln/channelmanager.rs

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions lightning/src/ln/functional_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use crate::io;
use crate::prelude::*;
use core::cell::RefCell;
use alloc::rc::Rc;
use crate::sync::{Arc, Mutex};
use crate::sync::{Arc, Mutex, LockTestExt};
use core::mem;
use core::iter::repeat;
use bitcoin::{PackedLockTime, TxMerkleNode};
Expand Down Expand Up @@ -466,8 +466,8 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
panic!();
}
}
assert_eq!(*chain_source.watched_txn.lock().unwrap(), *self.chain_source.watched_txn.lock().unwrap());
assert_eq!(*chain_source.watched_outputs.lock().unwrap(), *self.chain_source.watched_outputs.lock().unwrap());
assert_eq!(*chain_source.watched_txn.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_txn.unsafe_well_ordered_double_lock_self());
assert_eq!(*chain_source.watched_outputs.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_outputs.unsafe_well_ordered_double_lock_self());
}
}
}
Expand Down Expand Up @@ -2151,9 +2151,10 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
assert!(err.contains("Cannot send value that would put us over the max HTLC value in flight our peer will accept")));
}

pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) {
let our_payment_preimage = route_payment(&origin, expected_route, recv_value).0;
claim_payment(&origin, expected_route, our_payment_preimage);
pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) {
let res = route_payment(&origin, expected_route, recv_value);
claim_payment(&origin, expected_route, res.0);
res
}

pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash) {
Expand Down
55 changes: 29 additions & 26 deletions lightning/src/ln/functional_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4083,7 +4083,7 @@ fn do_test_htlc_timeout(send_partial_mpp: bool) {
let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match.
let payment_id = PaymentId([42; 32]);
let session_privs = nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(payment_secret), payment_id, &route).unwrap();
nodes[0].node.send_payment_along_path(&route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
nodes[0].node.test_send_payment_along_path(&route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
check_added_monitors!(nodes[0], 1);
let mut events = nodes[0].node.get_and_clear_pending_msg_events();
assert_eq!(events.len(), 1);
Expand Down Expand Up @@ -8150,12 +8150,13 @@ fn test_update_err_monitor_lockdown() {
let logger = test_utils::TestLogger::with_id(format!("node {}", 0));
let persister = test_utils::TestPersister::new();
let watchtower = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let mut w = test_utils::TestVecWriter(Vec::new());
monitor.write(&mut w).unwrap();
let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
let new_monitor = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
new_monitor
};
let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
watchtower
Expand Down Expand Up @@ -8217,12 +8218,13 @@ fn test_concurrent_monitor_claim() {
let logger = test_utils::TestLogger::with_id(format!("node {}", "Alice"));
let persister = test_utils::TestPersister::new();
let watchtower_alice = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let mut w = test_utils::TestVecWriter(Vec::new());
monitor.write(&mut w).unwrap();
let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
let new_monitor = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
new_monitor
};
let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
watchtower
Expand All @@ -8246,12 +8248,13 @@ fn test_concurrent_monitor_claim() {
let logger = test_utils::TestLogger::with_id(format!("node {}", "Bob"));
let persister = test_utils::TestPersister::new();
let watchtower_bob = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let mut w = test_utils::TestVecWriter(Vec::new());
monitor.write(&mut w).unwrap();
let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
let new_monitor = {
let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
&mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
assert!(new_monitor == *monitor);
new_monitor
};
let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
watchtower
Expand Down Expand Up @@ -9141,20 +9144,20 @@ fn test_inconsistent_mpp_params() {
dup_route.paths.push(route.paths[1].clone());
nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(our_payment_secret), payment_id, &dup_route).unwrap()
};
{
nodes[0].node.send_payment_along_path(&route.paths[0], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
check_added_monitors!(nodes[0], 1);
nodes[0].node.test_send_payment_along_path(&route.paths[0], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[0]).unwrap();
check_added_monitors!(nodes[0], 1);

{
let mut events = nodes[0].node.get_and_clear_pending_msg_events();
assert_eq!(events.len(), 1);
pass_along_path(&nodes[0], &[&nodes[1], &nodes[3]], 15_000_000, our_payment_hash, Some(our_payment_secret), events.pop().unwrap(), false, None);
}
assert!(nodes[3].node.get_and_clear_pending_events().is_empty());

{
nodes[0].node.send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 14_000_000, cur_height, payment_id, &None, session_privs[1]).unwrap();
check_added_monitors!(nodes[0], 1);
nodes[0].node.test_send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 14_000_000, cur_height, payment_id, &None, session_privs[1]).unwrap();
check_added_monitors!(nodes[0], 1);

{
let mut events = nodes[0].node.get_and_clear_pending_msg_events();
assert_eq!(events.len(), 1);
let payment_event = SendEvent::from_event(events.pop().unwrap());
Expand Down Expand Up @@ -9197,7 +9200,7 @@ fn test_inconsistent_mpp_params() {

expect_payment_failed_conditions(&nodes[0], our_payment_hash, true, PaymentFailedConditions::new().mpp_parts_remain());

nodes[0].node.send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[2]).unwrap();
nodes[0].node.test_send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[2]).unwrap();
check_added_monitors!(nodes[0], 1);

let mut events = nodes[0].node.get_and_clear_pending_msg_events();
Expand Down
53 changes: 27 additions & 26 deletions lightning/src/ln/payment_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1192,33 +1192,31 @@ fn test_trivial_inflight_htlc_tracking(){
let (_, _, chan_2_id, _) = create_announced_chan_between_nodes(&nodes, 1, 2);

// Send and claim the payment. Inflight HTLCs should be empty.
let (route, payment_hash, payment_preimage, payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[2], 500000);
nodes[0].node.send_payment(&route, payment_hash, &Some(payment_secret), PaymentId(payment_hash.0)).unwrap();
check_added_monitors!(nodes[0], 1);
pass_along_route(&nodes[0], &[&vec!(&nodes[1], &nodes[2])[..]], 500000, payment_hash, payment_secret);
claim_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], payment_preimage);
let payment_hash = send_payment(&nodes[0], &[&nodes[1], &nodes[2]], 500000).1;
let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
{
let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();

let mut node_0_per_peer_lock;
let mut node_0_peer_state_lock;
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
let channel_1 = get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);

let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
&NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
&NodeId::from_pubkey(&nodes[1].node.get_our_node_id()),
channel_1.get_short_channel_id().unwrap()
);
assert_eq!(chan_1_used_liquidity, None);
}
{
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);

let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat(
&NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) ,
&NodeId::from_pubkey(&nodes[2].node.get_our_node_id()),
channel_2.get_short_channel_id().unwrap()
);

assert_eq!(chan_1_used_liquidity, None);
assert_eq!(chan_2_used_liquidity, None);
}
let pending_payments = nodes[0].node.list_recent_payments();
Expand All @@ -1231,30 +1229,32 @@ fn test_trivial_inflight_htlc_tracking(){
}

// Send the payment, but do not claim it. Our inflight HTLCs should contain the pending payment.
let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 500000);
let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 500000);
let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
{
let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();

let mut node_0_per_peer_lock;
let mut node_0_peer_state_lock;
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
let channel_1 = get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);

let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
&NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
&NodeId::from_pubkey(&nodes[1].node.get_our_node_id()),
channel_1.get_short_channel_id().unwrap()
);
// First hop accounts for expected 1000 msat fee
assert_eq!(chan_1_used_liquidity, Some(501000));
}
{
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);

let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat(
&NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) ,
&NodeId::from_pubkey(&nodes[2].node.get_our_node_id()),
channel_2.get_short_channel_id().unwrap()
);

// First hop accounts for expected 1000 msat fee
assert_eq!(chan_1_used_liquidity, Some(501000));
assert_eq!(chan_2_used_liquidity, Some(500000));
}
let pending_payments = nodes[0].node.list_recent_payments();
Expand All @@ -1269,28 +1269,29 @@ fn test_trivial_inflight_htlc_tracking(){
nodes[0].node.timer_tick_occurred();
}

let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
{
let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();

let mut node_0_per_peer_lock;
let mut node_0_peer_state_lock;
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
let channel_1 = get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);

let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
&NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
&NodeId::from_pubkey(&nodes[1].node.get_our_node_id()),
channel_1.get_short_channel_id().unwrap()
);
assert_eq!(chan_1_used_liquidity, None);
}
{
let mut node_1_per_peer_lock;
let mut node_1_peer_state_lock;
let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);

let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat(
&NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) ,
&NodeId::from_pubkey(&nodes[2].node.get_our_node_id()),
channel_2.get_short_channel_id().unwrap()
);

assert_eq!(chan_1_used_liquidity, None);
assert_eq!(chan_2_used_liquidity, None);
}

Expand Down
Loading

0 comments on commit 8311581

Please sign in to comment.