Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace BTreeSet in IndexedMap with sorted Vec #2040

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 39 additions & 15 deletions fuzz/src/indexedmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,27 @@ use hashbrown::HashSet;

use crate::utils::test_logger;

fn check_eq(btree: &BTreeMap<u8, u8>, indexed: &IndexedMap<u8, u8>) {
use std::ops::{RangeBounds, Bound};

struct ExclLowerInclUpper(u8, u8);
impl RangeBounds<u8> for ExclLowerInclUpper {
fn start_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.0) }
fn end_bound(&self) -> Bound<&u8> { Bound::Included(&self.1) }
}
struct ExclLowerExclUpper(u8, u8);
impl RangeBounds<u8> for ExclLowerExclUpper {
fn start_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.0) }
fn end_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.1) }
}

fn check_eq(btree: &BTreeMap<u8, u8>, mut indexed: IndexedMap<u8, u8>) {
assert_eq!(btree.len(), indexed.len());
assert_eq!(btree.is_empty(), indexed.is_empty());

let mut btree_clone = btree.clone();
assert!(btree_clone == *btree);
let mut indexed_clone = indexed.clone();
assert!(indexed_clone == *indexed);
assert!(indexed_clone == indexed);

for k in 0..=255 {
assert_eq!(btree.contains_key(&k), indexed.contains_key(&k));
Expand All @@ -43,16 +56,27 @@ fn check_eq(btree: &BTreeMap<u8, u8>, indexed: &IndexedMap<u8, u8>) {
}

const STRIDE: u8 = 16;
for k in 0..=255/STRIDE {
let lower_bound = k * STRIDE;
let upper_bound = lower_bound + (STRIDE - 1);
let mut btree_iter = btree.range(lower_bound..=upper_bound);
let mut indexed_iter = indexed.range(lower_bound..=upper_bound);
loop {
let b_v = btree_iter.next();
let i_v = indexed_iter.next();
assert_eq!(b_v, i_v);
if b_v.is_none() { break; }
for range_type in 0..4 {
for k in 0..=255/STRIDE {
let lower_bound = k * STRIDE;
let upper_bound = lower_bound + (STRIDE - 1);
macro_rules! range { ($map: expr) => {
match range_type {
0 => $map.range(lower_bound..upper_bound),
1 => $map.range(lower_bound..=upper_bound),
2 => $map.range(ExclLowerInclUpper(lower_bound, upper_bound)),
3 => $map.range(ExclLowerExclUpper(lower_bound, upper_bound)),
_ => unreachable!(),
}
} }
let mut btree_iter = range!(btree);
let mut indexed_iter = range!(indexed);
loop {
let b_v = btree_iter.next();
let i_v = indexed_iter.next();
assert_eq!(b_v, i_v);
if b_v.is_none() { break; }
}
}
}

Expand Down Expand Up @@ -91,15 +115,15 @@ pub fn do_test(data: &[u8]) {
let prev_value_i = indexed.insert(tuple[0], tuple[1]);
assert_eq!(prev_value_b, prev_value_i);
}
check_eq(&btree, &indexed);
check_eq(&btree, indexed.clone());

// Now, modify the maps in all the ways we have to do so, checking that the maps remain
// equivalent as we go.
for (k, v) in indexed.unordered_iter_mut() {
*v = *k;
*btree.get_mut(k).unwrap() = *k;
}
check_eq(&btree, &indexed);
check_eq(&btree, indexed.clone());

for k in 0..=255 {
match btree.entry(k) {
Expand All @@ -124,7 +148,7 @@ pub fn do_test(data: &[u8]) {
},
}
}
check_eq(&btree, &indexed);
check_eq(&btree, indexed);
}

pub fn indexedmap_test<Out: test_logger::Output>(data: &[u8], _out: Out) {
Expand Down
6 changes: 3 additions & 3 deletions lightning/src/routing/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ where U::Target: UtxoLookup, L::Target: Logger
}

fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
let channels = self.network_graph.channels.read().unwrap();
let mut channels = self.network_graph.channels.write().unwrap();
for (_, ref chan) in channels.range(starting_point..) {
if chan.announcement_message.is_some() {
let chan_announcement = chan.announcement_message.clone().unwrap();
Expand All @@ -412,7 +412,7 @@ where U::Target: UtxoLookup, L::Target: Logger
}

fn get_next_node_announcement(&self, starting_point: Option<&NodeId>) -> Option<NodeAnnouncement> {
let nodes = self.network_graph.nodes.read().unwrap();
let mut nodes = self.network_graph.nodes.write().unwrap();
let iter = if let Some(node_id) = starting_point {
nodes.range((Bound::Excluded(node_id), Bound::Unbounded))
} else {
Expand Down Expand Up @@ -572,7 +572,7 @@ where U::Target: UtxoLookup, L::Target: Logger
// (has at least one update). A peer may still want to know the channel
// exists even if its not yet routable.
let mut batches: Vec<Vec<u64>> = vec![Vec::with_capacity(MAX_SCIDS_PER_REPLY)];
let channels = self.network_graph.channels.read().unwrap();
let mut channels = self.network_graph.channels.write().unwrap();
for (_, ref chan) in channels.range(inclusive_start_scid.unwrap()..exclusive_end_scid.unwrap()) {
if let Some(chan_announcement) = &chan.announcement_message {
// Construct a new batch if last one is full
Expand Down
50 changes: 35 additions & 15 deletions lightning/src/util/indexed_map.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`].

use crate::prelude::{HashMap, hash_map};
use alloc::collections::{BTreeSet, btree_set};
use alloc::vec::Vec;
use alloc::slice::Iter;
use core::hash::Hash;
use core::cmp::Ord;
use core::ops::RangeBounds;
use core::ops::{Bound, RangeBounds};

/// A map which can be iterated in a deterministic order.
///
Expand All @@ -21,19 +22,18 @@ use core::ops::RangeBounds;
/// keys in the order defined by [`Ord`].
///
/// [`BTreeMap`]: alloc::collections::BTreeMap
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, Eq)]
pub struct IndexedMap<K: Hash + Ord, V> {
map: HashMap<K, V>,
// TODO: Explore swapping this for a sorted vec (that is only sorted on first range() call)
keys: BTreeSet<K>,
keys: Vec<K>,
}

impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
/// Constructs a new, empty map
pub fn new() -> Self {
Self {
map: HashMap::new(),
keys: BTreeSet::new(),
keys: Vec::new(),
}
}

Expand All @@ -58,7 +58,8 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
pub fn remove(&mut self, key: &K) -> Option<V> {
let ret = self.map.remove(key);
if let Some(_) = ret {
assert!(self.keys.remove(key), "map and keys must be consistent");
let idx = self.keys.iter().position(|k| k == key).expect("map and keys must be consistent");
self.keys.remove(idx);
}
ret
}
Expand All @@ -68,7 +69,7 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
let ret = self.map.insert(key.clone(), value);
if ret.is_none() {
assert!(self.keys.insert(key), "map and keys must be consistent");
self.keys.push(key);
}
ret
}
Expand Down Expand Up @@ -109,9 +110,21 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
}

/// Returns an iterator which iterates over the `key`/`value` pairs in a given range.
pub fn range<R: RangeBounds<K>>(&self, range: R) -> Range<K, V> {
pub fn range<R: RangeBounds<K>>(&mut self, range: R) -> Range<K, V> {
self.keys.sort_unstable();
let start = match range.start_bound() {
Bound::Unbounded => 0,
Bound::Included(key) => self.keys.binary_search(key).unwrap_or_else(|index| index),
alecchendev marked this conversation as resolved.
Show resolved Hide resolved
Bound::Excluded(key) => self.keys.binary_search(key).and_then(|index| Ok(index + 1)).unwrap_or_else(|index| index),
};
let end = match range.end_bound() {
TheBlueMatt marked this conversation as resolved.
Show resolved Hide resolved
Bound::Unbounded => self.keys.len(),
Bound::Included(key) => self.keys.binary_search(key).and_then(|index| Ok(index + 1)).unwrap_or_else(|index| index),
Bound::Excluded(key) => self.keys.binary_search(key).unwrap_or_else(|index| index),
};

Range {
inner_range: self.keys.range(range),
inner_range: self.keys[start..end].iter(),
map: &self.map,
}
}
Expand All @@ -127,9 +140,15 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
}
}

impl<K: Hash + Ord + PartialEq, V: PartialEq> PartialEq for IndexedMap<K, V> {
fn eq(&self, other: &Self) -> bool {
self.map == other.map
}
}

/// An iterator over a range of values in an [`IndexedMap`]
pub struct Range<'a, K: Hash + Ord, V> {
inner_range: btree_set::Range<'a, K>,
inner_range: Iter<'a, K>,
map: &'a HashMap<K, V>,
}
impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> {
Expand All @@ -148,7 +167,7 @@ pub struct VacantEntry<'a, K: Hash + Ord, V> {
#[cfg(not(feature = "hashbrown"))]
underlying_entry: hash_map::VacantEntry<'a, K, V>,
key: K,
keys: &'a mut BTreeSet<K>,
keys: &'a mut Vec<K>,
}

/// An [`Entry`] for an existing key-value pair
Expand All @@ -157,7 +176,7 @@ pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>,
#[cfg(not(feature = "hashbrown"))]
underlying_entry: hash_map::OccupiedEntry<'a, K, V>,
keys: &'a mut BTreeSet<K>,
keys: &'a mut Vec<K>,
}

/// A mutable reference to a position in the map. This can be used to reference, add, or update the
Expand All @@ -172,7 +191,7 @@ pub enum Entry<'a, K: Hash + Ord, V> {
impl<'a, K: Hash + Ord, V> VacantEntry<'a, K, V> {
/// Insert a value into the position described by this entry.
pub fn insert(self, value: V) -> &'a mut V {
assert!(self.keys.insert(self.key), "map and keys must be consistent");
self.keys.push(self.key);
self.underlying_entry.insert(value)
}
}
Expand All @@ -181,7 +200,8 @@ impl<'a, K: Hash + Ord, V> OccupiedEntry<'a, K, V> {
/// Remove the value at the position described by this entry.
pub fn remove_entry(self) -> (K, V) {
let res = self.underlying_entry.remove_entry();
assert!(self.keys.remove(&res.0), "map and keys must be consistent");
let idx = self.keys.iter().position(|k| k == &res.0).expect("map and keys must be consistent");
self.keys.remove(idx);
res
}

Expand Down